Polish spring-security-messaging main code

Manually polish `spring-security-messaging` following the formatting
and checkstyle fixes.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-07-31 13:46:39 -07:00 committed by Rob Winch
parent 554ef627fb
commit ad1dbf425f
16 changed files with 60 additions and 103 deletions

View File

@ -19,8 +19,7 @@ package org.springframework.security.messaging.access.expression;
import org.springframework.expression.EvaluationContext;
/**
*
* /** Allows post processing the {@link EvaluationContext}
* Allows post processing the {@link EvaluationContext}
*
* <p>
* This API is intentionally kept package scope as it may evolve over time.

View File

@ -38,6 +38,9 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher;
*/
public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
private ExpressionBasedMessageSecurityMetadataSourceFactory() {
}
/**
* Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher}
* mapped to Spring Expressions. Each entry is considered in order and only the first
@ -108,9 +111,7 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
public static MessageSecurityMetadataSource createExpressionMessageMetadataSource(
LinkedHashMap<MessageMatcher<?>, String> matcherToExpression,
SecurityExpressionHandler<Message<Object>> handler) {
LinkedHashMap<MessageMatcher<?>, Collection<ConfigAttribute>> matcherToAttrs = new LinkedHashMap<>();
for (Map.Entry<MessageMatcher<?>, String> entry : matcherToExpression.entrySet()) {
MessageMatcher<?> matcher = entry.getKey();
String rawExpression = entry.getValue();
@ -121,7 +122,4 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
return new DefaultMessageSecurityMetadataSource(matcherToAttrs);
}
private ExpressionBasedMessageSecurityMetadataSourceFactory() {
}
}

View File

@ -69,7 +69,7 @@ class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationCon
@Override
public EvaluationContext postProcess(EvaluationContext ctx, Message<?> message) {
if (this.matcher instanceof SimpDestinationMessageMatcher) {
final Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
.extractPathVariables(message);
for (Map.Entry<String, String> entry : variables.entrySet()) {
ctx.setVariable(entry.getKey(), entry.getValue());

View File

@ -44,19 +44,15 @@ public class MessageExpressionVoter<T> implements AccessDecisionVoter<Message<T>
@Override
public int vote(Authentication authentication, Message<T> message, Collection<ConfigAttribute> attributes) {
assert authentication != null;
assert message != null;
assert attributes != null;
Assert.notNull(authentication, "authentication must not be null");
Assert.notNull(message, "message must not be null");
Assert.notNull(attributes, "attributes must not be null");
MessageExpressionConfigAttribute attr = findConfigAttribute(attributes);
if (attr == null) {
return ACCESS_ABSTAIN;
}
EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message);
ctx = attr.postProcess(ctx, message);
return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED;
}

View File

@ -65,11 +65,9 @@ public final class DefaultMessageSecurityMetadataSource implements MessageSecuri
@Override
public Collection<ConfigAttribute> getAllConfigAttributes() {
Set<ConfigAttribute> allAttributes = new HashSet<>();
for (Collection<ConfigAttribute> entry : this.messageMap.values()) {
allAttributes.addAll(entry);
}
return allAttributes;
}

View File

@ -98,26 +98,20 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
return null;
}
Object principal = authentication.getPrincipal();
AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
String expressionToParse = authPrincipal.expression();
if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(principal);
context.setVariable("this", principal);
Expression expression = this.parser.parseExpression(expressionToParse);
principal = expression.getValue(context);
}
if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) {
if (authPrincipal.errorOnInvalidType()) {
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
}
else {
return null;
}
return null;
}
return principal;
}

View File

@ -43,9 +43,9 @@ import org.springframework.util.Assert;
public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter
implements ExecutorChannelInterceptor {
private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
private static final ThreadLocal<Stack<SecurityContext>> ORIGINAL_CONTEXT = new ThreadLocal<>();
private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
private final String authenticationHeaderName;
@ -110,46 +110,41 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
private void setup(Message<?> message) {
SecurityContext currentContext = SecurityContextHolder.getContext();
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null) {
contextStack = new Stack<>();
ORIGINAL_CONTEXT.set(contextStack);
originalContext.set(contextStack);
}
contextStack.push(currentContext);
Object user = message.getHeaders().get(this.authenticationHeaderName);
Authentication authentication;
if ((user instanceof Authentication)) {
authentication = (Authentication) user;
}
else {
authentication = this.anonymous;
}
Authentication authentication = getAuthentication(user);
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
}
private void cleanup() {
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
private Authentication getAuthentication(Object user) {
if ((user instanceof Authentication)) {
return (Authentication) user;
}
return this.anonymous;
}
private void cleanup() {
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null || contextStack.isEmpty()) {
SecurityContextHolder.clearContext();
ORIGINAL_CONTEXT.remove();
originalContext.remove();
return;
}
SecurityContext originalContext = contextStack.pop();
SecurityContext context = contextStack.pop();
try {
if (this.EMPTY_CONTEXT.equals(originalContext)) {
if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) {
SecurityContextHolder.clearContext();
ORIGINAL_CONTEXT.remove();
originalContext.remove();
}
else {
SecurityContextHolder.setContext(originalContext);
SecurityContextHolder.setContext(context);
}
}
catch (Throwable ex) {

View File

@ -134,28 +134,21 @@ public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArg
private Object resolvePrincipal(MethodParameter parameter, Object principal) {
AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
String expressionToParse = authPrincipal.expression();
if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(principal);
context.setVariable("this", principal);
context.setBeanResolver(this.beanResolver);
Expression expression = this.parser.parseExpression(expressionToParse);
principal = expression.getValue(context);
}
if (isInvalidType(parameter, principal)) {
if (authPrincipal.errorOnInvalidType()) {
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
}
else {
return null;
}
return null;
}
return principal;
}

View File

@ -133,28 +133,21 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu
private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
String expressionToParse = contextAnno.expression();
if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(securityContext);
context.setVariable("this", securityContext);
context.setBeanResolver(this.beanResolver);
Expression expression = this.parser.parseExpression(expressionToParse);
securityContext = expression.getValue(context);
}
if (isInvalidType(parameter, securityContext)) {
if (contextAnno.errorOnInvalidType()) {
throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType());
}
else {
return null;
}
return null;
}
return securityContext;
}

View File

@ -31,7 +31,13 @@ import org.springframework.util.Assert;
*/
public abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
protected final Log LOGGER = LogFactory.getLog(getClass());
protected final Log logger = LogFactory.getLog(getClass());
/**
* @deprecated since 5.4 in favor of {@link #logger}
*/
@Deprecated
protected final Log LOGGER = this.logger;
private final List<MessageMatcher<T>> messageMatchers;
@ -41,9 +47,7 @@ public abstract class AbstractMessageMatcherComposite<T> implements MessageMatch
*/
AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) {
Assert.notEmpty(messageMatchers, "messageMatchers must contain a value");
if (messageMatchers.contains(null)) {
throw new IllegalArgumentException("messageMatchers cannot contain null values");
}
Assert.isTrue(!messageMatchers.contains(null), "messageMatchers cannot contain null values");
this.messageMatchers = messageMatchers;
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
import java.util.List;
import org.springframework.core.log.LogMessage;
import org.springframework.messaging.Message;
/**
@ -49,15 +50,13 @@ public final class AndMessageMatcher<T> extends AbstractMessageMatcherComposite<
@Override
public boolean matches(Message<? extends T> message) {
for (MessageMatcher<T> matcher : getMessageMatchers()) {
if (this.LOGGER.isDebugEnabled()) {
this.LOGGER.debug("Trying to match using " + matcher);
}
this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
if (!matcher.matches(message)) {
this.LOGGER.debug("Did not match");
this.logger.debug("Did not match");
return false;
}
}
this.LOGGER.debug("All messageMatchers returned true");
this.logger.debug("All messageMatchers returned true");
return true;
}

View File

@ -26,17 +26,11 @@ import org.springframework.messaging.Message;
*/
public interface MessageMatcher<T> {
/**
* Returns true if the {@link Message} matches, else false
* @param message the {@link Message} to match on
* @return true if the {@link Message} matches, else false
*/
boolean matches(Message<? extends T> message);
/**
* Matches every {@link Message}
*/
MessageMatcher<Object> ANY_MESSAGE = new MessageMatcher<Object>() {
@Override
public boolean matches(Message<?> message) {
return true;
@ -46,6 +40,14 @@ public interface MessageMatcher<T> {
public String toString() {
return "ANY_MESSAGE";
}
};
/**
* Returns true if the {@link Message} matches, else false
* @param message the {@link Message} to match on
* @return true if the {@link Message} matches, else false
*/
boolean matches(Message<? extends T> message);
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
import java.util.List;
import org.springframework.core.log.LogMessage;
import org.springframework.messaging.Message;
/**
@ -49,15 +50,13 @@ public final class OrMessageMatcher<T> extends AbstractMessageMatcherComposite<T
@Override
public boolean matches(Message<? extends T> message) {
for (MessageMatcher<T> matcher : getMessageMatchers()) {
if (this.LOGGER.isDebugEnabled()) {
this.LOGGER.debug("Trying to match using " + matcher);
}
this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
if (matcher.matches(message)) {
this.LOGGER.debug("matched");
this.logger.debug("matched");
return true;
}
}
this.LOGGER.debug("No matches found");
this.logger.debug("No matches found");
return false;
}

View File

@ -107,11 +107,8 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, PathMatcher pathMatcher) {
Assert.notNull(pattern, "pattern cannot be null");
Assert.notNull(pathMatcher, "pathMatcher cannot be null");
if (!isTypeWithDestination(type)) {
throw new IllegalArgumentException(
"SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
}
Assert.isTrue(isTypeWithDestination(type),
() -> "SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
this.matcher = pathMatcher;
this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE;
this.pattern = pattern;
@ -122,7 +119,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
if (!this.messageTypeMatcher.matches(message)) {
return false;
}
String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
return destination != null && this.matcher.match(this.pattern, destination);
}
@ -144,10 +140,7 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
}
private boolean isTypeWithDestination(SimpMessageType type) {
if (type == null) {
return true;
}
return SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
return type == null || SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
}
/**

View File

@ -49,7 +49,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
public boolean matches(Message<?> message) {
MessageHeaders headers = message.getHeaders();
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
return this.typeToMatch == messageType;
}
@ -63,7 +62,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
}
SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other;
return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch);
}
@Override

View File

@ -46,23 +46,19 @@ public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter {
if (!this.matcher.matches(message)) {
return message;
}
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
CsrfToken expectedToken = (sessionAttributes != null)
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
if (expectedToken == null) {
throw new MissingCsrfTokenException(null);
}
String actualTokenValue = SimpMessageHeaderAccessor.wrap(message)
.getFirstNativeHeader(expectedToken.getHeaderName());
boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue);
if (csrfCheckPassed) {
return message;
if (!csrfCheckPassed) {
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
}
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
return message;
}
}