diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java index 72cd1f8bc2..ef7a10f438 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java @@ -19,8 +19,7 @@ package org.springframework.security.messaging.access.expression; import org.springframework.expression.EvaluationContext; /** - * - * /** Allows post processing the {@link EvaluationContext} + * Allows post processing the {@link EvaluationContext} * *

* This API is intentionally kept package scope as it may evolve over time. diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java index 0fba7760b7..a819ce4cd3 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java @@ -38,6 +38,9 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher; */ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { + private ExpressionBasedMessageSecurityMetadataSourceFactory() { + } + /** * Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher} * mapped to Spring Expressions. Each entry is considered in order and only the first @@ -108,9 +111,7 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { public static MessageSecurityMetadataSource createExpressionMessageMetadataSource( LinkedHashMap, String> matcherToExpression, SecurityExpressionHandler> handler) { - LinkedHashMap, Collection> matcherToAttrs = new LinkedHashMap<>(); - for (Map.Entry, String> entry : matcherToExpression.entrySet()) { MessageMatcher matcher = entry.getKey(); String rawExpression = entry.getValue(); @@ -121,7 +122,4 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory { return new DefaultMessageSecurityMetadataSource(matcherToAttrs); } - private ExpressionBasedMessageSecurityMetadataSourceFactory() { - } - } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java index 598ca2e339..e663c4f06a 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java @@ -69,7 +69,7 @@ class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationCon @Override public EvaluationContext postProcess(EvaluationContext ctx, Message message) { if (this.matcher instanceof SimpDestinationMessageMatcher) { - final Map variables = ((SimpDestinationMessageMatcher) this.matcher) + Map variables = ((SimpDestinationMessageMatcher) this.matcher) .extractPathVariables(message); for (Map.Entry entry : variables.entrySet()) { ctx.setVariable(entry.getKey(), entry.getValue()); diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java index 3ec994cd77..b097df8c1e 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java @@ -44,19 +44,15 @@ public class MessageExpressionVoter implements AccessDecisionVoter @Override public int vote(Authentication authentication, Message message, Collection attributes) { - assert authentication != null; - assert message != null; - assert attributes != null; - + Assert.notNull(authentication, "authentication must not be null"); + Assert.notNull(message, "message must not be null"); + Assert.notNull(attributes, "attributes must not be null"); MessageExpressionConfigAttribute attr = findConfigAttribute(attributes); - if (attr == null) { return ACCESS_ABSTAIN; } - EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message); ctx = attr.postProcess(ctx, message); - return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java index c1f3057c87..6e3eb8ba41 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java @@ -65,11 +65,9 @@ public final class DefaultMessageSecurityMetadataSource implements MessageSecuri @Override public Collection getAllConfigAttributes() { Set allAttributes = new HashSet<>(); - for (Collection entry : this.messageMap.values()) { allAttributes.addAll(entry); } - return allAttributes; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java index 031752ffef..58cd4f720e 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java @@ -98,26 +98,20 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet return null; } Object principal = authentication.getPrincipal(); - AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); - String expressionToParse = authPrincipal.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(principal); context.setVariable("this", principal); - Expression expression = this.parser.parseExpression(expressionToParse); principal = expression.getValue(context); } - if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) { if (authPrincipal.errorOnInvalidType()) { throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } return principal; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java index 409043edcb..594cfcacba 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java @@ -43,9 +43,9 @@ import org.springframework.util.Assert; public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter implements ExecutorChannelInterceptor { - private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); + private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); - private static final ThreadLocal> ORIGINAL_CONTEXT = new ThreadLocal<>(); + private static final ThreadLocal> originalContext = new ThreadLocal<>(); private final String authenticationHeaderName; @@ -110,46 +110,41 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA private void setup(Message message) { SecurityContext currentContext = SecurityContextHolder.getContext(); - - Stack contextStack = ORIGINAL_CONTEXT.get(); + Stack contextStack = originalContext.get(); if (contextStack == null) { contextStack = new Stack<>(); - ORIGINAL_CONTEXT.set(contextStack); + originalContext.set(contextStack); } contextStack.push(currentContext); - Object user = message.getHeaders().get(this.authenticationHeaderName); - - Authentication authentication; - if ((user instanceof Authentication)) { - authentication = (Authentication) user; - } - else { - authentication = this.anonymous; - } + Authentication authentication = getAuthentication(user); SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); } - private void cleanup() { - Stack contextStack = ORIGINAL_CONTEXT.get(); + private Authentication getAuthentication(Object user) { + if ((user instanceof Authentication)) { + return (Authentication) user; + } + return this.anonymous; + } + private void cleanup() { + Stack contextStack = originalContext.get(); if (contextStack == null || contextStack.isEmpty()) { SecurityContextHolder.clearContext(); - ORIGINAL_CONTEXT.remove(); + originalContext.remove(); return; } - - SecurityContext originalContext = contextStack.pop(); - + SecurityContext context = contextStack.pop(); try { - if (this.EMPTY_CONTEXT.equals(originalContext)) { + if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) { SecurityContextHolder.clearContext(); - ORIGINAL_CONTEXT.remove(); + originalContext.remove(); } else { - SecurityContextHolder.setContext(originalContext); + SecurityContextHolder.setContext(context); } } catch (Throwable ex) { diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java index 237b79cc13..a857232d44 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java @@ -134,28 +134,21 @@ public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArg private Object resolvePrincipal(MethodParameter parameter, Object principal) { AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); - String expressionToParse = authPrincipal.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(principal); context.setVariable("this", principal); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); principal = expression.getValue(context); } - if (isInvalidType(parameter, principal)) { - if (authPrincipal.errorOnInvalidType()) { throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } - return principal; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java index 4b835e3f97..89491d2b17 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java @@ -133,28 +133,21 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) { CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter); - String expressionToParse = contextAnno.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(securityContext); context.setVariable("this", securityContext); context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); securityContext = expression.getValue(context); } - if (isInvalidType(parameter, securityContext)) { - if (contextAnno.errorOnInvalidType()) { throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType()); } - else { - return null; - } + return null; } - return securityContext; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java index fc1ca1f2c0..bf899ba1df 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java @@ -31,7 +31,13 @@ import org.springframework.util.Assert; */ public abstract class AbstractMessageMatcherComposite implements MessageMatcher { - protected final Log LOGGER = LogFactory.getLog(getClass()); + protected final Log logger = LogFactory.getLog(getClass()); + + /** + * @deprecated since 5.4 in favor of {@link #logger} + */ + @Deprecated + protected final Log LOGGER = this.logger; private final List> messageMatchers; @@ -41,9 +47,7 @@ public abstract class AbstractMessageMatcherComposite implements MessageMatch */ AbstractMessageMatcherComposite(List> messageMatchers) { Assert.notEmpty(messageMatchers, "messageMatchers must contain a value"); - if (messageMatchers.contains(null)) { - throw new IllegalArgumentException("messageMatchers cannot contain null values"); - } + Assert.isTrue(!messageMatchers.contains(null), "messageMatchers cannot contain null values"); this.messageMatchers = messageMatchers; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java index 45767fc56e..6edc0c4ef4 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java @@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher; import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.messaging.Message; /** @@ -49,15 +50,13 @@ public final class AndMessageMatcher extends AbstractMessageMatcherComposite< @Override public boolean matches(Message message) { for (MessageMatcher matcher : getMessageMatchers()) { - if (this.LOGGER.isDebugEnabled()) { - this.LOGGER.debug("Trying to match using " + matcher); - } + this.logger.debug(LogMessage.format("Trying to match using %s", matcher)); if (!matcher.matches(message)) { - this.LOGGER.debug("Did not match"); + this.logger.debug("Did not match"); return false; } } - this.LOGGER.debug("All messageMatchers returned true"); + this.logger.debug("All messageMatchers returned true"); return true; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java index 50d401ba3b..ffafb72a6a 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java @@ -26,17 +26,11 @@ import org.springframework.messaging.Message; */ public interface MessageMatcher { - /** - * Returns true if the {@link Message} matches, else false - * @param message the {@link Message} to match on - * @return true if the {@link Message} matches, else false - */ - boolean matches(Message message); - /** * Matches every {@link Message} */ MessageMatcher ANY_MESSAGE = new MessageMatcher() { + @Override public boolean matches(Message message) { return true; @@ -46,6 +40,14 @@ public interface MessageMatcher { public String toString() { return "ANY_MESSAGE"; } + }; + /** + * Returns true if the {@link Message} matches, else false + * @param message the {@link Message} to match on + * @return true if the {@link Message} matches, else false + */ + boolean matches(Message message); + } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java index fb9971cde3..010fe7aecf 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java @@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher; import java.util.List; +import org.springframework.core.log.LogMessage; import org.springframework.messaging.Message; /** @@ -49,15 +50,13 @@ public final class OrMessageMatcher extends AbstractMessageMatcherComposite message) { for (MessageMatcher matcher : getMessageMatchers()) { - if (this.LOGGER.isDebugEnabled()) { - this.LOGGER.debug("Trying to match using " + matcher); - } + this.logger.debug(LogMessage.format("Trying to match using %s", matcher)); if (matcher.matches(message)) { - this.LOGGER.debug("matched"); + this.logger.debug("matched"); return true; } } - this.LOGGER.debug("No matches found"); + this.logger.debug("No matches found"); return false; } diff --git a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java index 39a52b3beb..d4ae0e15d6 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java +++ b/messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java @@ -107,11 +107,8 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher "SimpMessageType " + type + " does not contain a destination and so cannot be matched on."); this.matcher = pathMatcher; this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE; this.pattern = pattern; @@ -122,7 +119,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher { public boolean matches(Message message) { MessageHeaders headers = message.getHeaders(); SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); - return this.typeToMatch == messageType; } @@ -63,7 +62,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher { } SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other; return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch); - } @Override diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java index e9b26e7f2f..059b34bddb 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java @@ -46,23 +46,19 @@ public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter { if (!this.matcher.matches(message)) { return message; } - Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); CsrfToken expectedToken = (sessionAttributes != null) ? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null; - if (expectedToken == null) { throw new MissingCsrfTokenException(null); } - String actualTokenValue = SimpMessageHeaderAccessor.wrap(message) .getFirstNativeHeader(expectedToken.getHeaderName()); - boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue); - if (csrfCheckPassed) { - return message; + if (!csrfCheckPassed) { + throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); } - throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); + return message; } }