Polish spring-security-messaging main code
Manually polish `spring-security-messaging` following the formatting and checkstyle fixes. Issue gh-8945
This commit is contained in:
parent
554ef627fb
commit
ad1dbf425f
|
@ -19,8 +19,7 @@ package org.springframework.security.messaging.access.expression;
|
|||
import org.springframework.expression.EvaluationContext;
|
||||
|
||||
/**
|
||||
*
|
||||
* /** Allows post processing the {@link EvaluationContext}
|
||||
* Allows post processing the {@link EvaluationContext}
|
||||
*
|
||||
* <p>
|
||||
* This API is intentionally kept package scope as it may evolve over time.
|
||||
|
|
|
@ -38,6 +38,9 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher;
|
|||
*/
|
||||
public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
|
||||
|
||||
private ExpressionBasedMessageSecurityMetadataSourceFactory() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher}
|
||||
* mapped to Spring Expressions. Each entry is considered in order and only the first
|
||||
|
@ -108,9 +111,7 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
|
|||
public static MessageSecurityMetadataSource createExpressionMessageMetadataSource(
|
||||
LinkedHashMap<MessageMatcher<?>, String> matcherToExpression,
|
||||
SecurityExpressionHandler<Message<Object>> handler) {
|
||||
|
||||
LinkedHashMap<MessageMatcher<?>, Collection<ConfigAttribute>> matcherToAttrs = new LinkedHashMap<>();
|
||||
|
||||
for (Map.Entry<MessageMatcher<?>, String> entry : matcherToExpression.entrySet()) {
|
||||
MessageMatcher<?> matcher = entry.getKey();
|
||||
String rawExpression = entry.getValue();
|
||||
|
@ -121,7 +122,4 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
|
|||
return new DefaultMessageSecurityMetadataSource(matcherToAttrs);
|
||||
}
|
||||
|
||||
private ExpressionBasedMessageSecurityMetadataSourceFactory() {
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationCon
|
|||
@Override
|
||||
public EvaluationContext postProcess(EvaluationContext ctx, Message<?> message) {
|
||||
if (this.matcher instanceof SimpDestinationMessageMatcher) {
|
||||
final Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
|
||||
Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
|
||||
.extractPathVariables(message);
|
||||
for (Map.Entry<String, String> entry : variables.entrySet()) {
|
||||
ctx.setVariable(entry.getKey(), entry.getValue());
|
||||
|
|
|
@ -44,19 +44,15 @@ public class MessageExpressionVoter<T> implements AccessDecisionVoter<Message<T>
|
|||
|
||||
@Override
|
||||
public int vote(Authentication authentication, Message<T> message, Collection<ConfigAttribute> attributes) {
|
||||
assert authentication != null;
|
||||
assert message != null;
|
||||
assert attributes != null;
|
||||
|
||||
Assert.notNull(authentication, "authentication must not be null");
|
||||
Assert.notNull(message, "message must not be null");
|
||||
Assert.notNull(attributes, "attributes must not be null");
|
||||
MessageExpressionConfigAttribute attr = findConfigAttribute(attributes);
|
||||
|
||||
if (attr == null) {
|
||||
return ACCESS_ABSTAIN;
|
||||
}
|
||||
|
||||
EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message);
|
||||
ctx = attr.postProcess(ctx, message);
|
||||
|
||||
return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED;
|
||||
}
|
||||
|
||||
|
|
|
@ -65,11 +65,9 @@ public final class DefaultMessageSecurityMetadataSource implements MessageSecuri
|
|||
@Override
|
||||
public Collection<ConfigAttribute> getAllConfigAttributes() {
|
||||
Set<ConfigAttribute> allAttributes = new HashSet<>();
|
||||
|
||||
for (Collection<ConfigAttribute> entry : this.messageMap.values()) {
|
||||
allAttributes.addAll(entry);
|
||||
}
|
||||
|
||||
return allAttributes;
|
||||
}
|
||||
|
||||
|
|
|
@ -98,26 +98,20 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
|
|||
return null;
|
||||
}
|
||||
Object principal = authentication.getPrincipal();
|
||||
|
||||
AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
|
||||
|
||||
String expressionToParse = authPrincipal.expression();
|
||||
if (StringUtils.hasLength(expressionToParse)) {
|
||||
StandardEvaluationContext context = new StandardEvaluationContext();
|
||||
context.setRootObject(principal);
|
||||
context.setVariable("this", principal);
|
||||
|
||||
Expression expression = this.parser.parseExpression(expressionToParse);
|
||||
principal = expression.getValue(context);
|
||||
}
|
||||
|
||||
if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) {
|
||||
if (authPrincipal.errorOnInvalidType()) {
|
||||
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
|
||||
}
|
||||
else {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
return principal;
|
||||
}
|
||||
|
|
|
@ -43,9 +43,9 @@ import org.springframework.util.Assert;
|
|||
public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter
|
||||
implements ExecutorChannelInterceptor {
|
||||
|
||||
private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
|
||||
private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
|
||||
|
||||
private static final ThreadLocal<Stack<SecurityContext>> ORIGINAL_CONTEXT = new ThreadLocal<>();
|
||||
private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
|
||||
|
||||
private final String authenticationHeaderName;
|
||||
|
||||
|
@ -110,46 +110,41 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
|
|||
|
||||
private void setup(Message<?> message) {
|
||||
SecurityContext currentContext = SecurityContextHolder.getContext();
|
||||
|
||||
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
|
||||
Stack<SecurityContext> contextStack = originalContext.get();
|
||||
if (contextStack == null) {
|
||||
contextStack = new Stack<>();
|
||||
ORIGINAL_CONTEXT.set(contextStack);
|
||||
originalContext.set(contextStack);
|
||||
}
|
||||
contextStack.push(currentContext);
|
||||
|
||||
Object user = message.getHeaders().get(this.authenticationHeaderName);
|
||||
|
||||
Authentication authentication;
|
||||
if ((user instanceof Authentication)) {
|
||||
authentication = (Authentication) user;
|
||||
}
|
||||
else {
|
||||
authentication = this.anonymous;
|
||||
}
|
||||
Authentication authentication = getAuthentication(user);
|
||||
SecurityContext context = SecurityContextHolder.createEmptyContext();
|
||||
context.setAuthentication(authentication);
|
||||
SecurityContextHolder.setContext(context);
|
||||
}
|
||||
|
||||
private void cleanup() {
|
||||
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
|
||||
private Authentication getAuthentication(Object user) {
|
||||
if ((user instanceof Authentication)) {
|
||||
return (Authentication) user;
|
||||
}
|
||||
return this.anonymous;
|
||||
}
|
||||
|
||||
private void cleanup() {
|
||||
Stack<SecurityContext> contextStack = originalContext.get();
|
||||
if (contextStack == null || contextStack.isEmpty()) {
|
||||
SecurityContextHolder.clearContext();
|
||||
ORIGINAL_CONTEXT.remove();
|
||||
originalContext.remove();
|
||||
return;
|
||||
}
|
||||
|
||||
SecurityContext originalContext = contextStack.pop();
|
||||
|
||||
SecurityContext context = contextStack.pop();
|
||||
try {
|
||||
if (this.EMPTY_CONTEXT.equals(originalContext)) {
|
||||
if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) {
|
||||
SecurityContextHolder.clearContext();
|
||||
ORIGINAL_CONTEXT.remove();
|
||||
originalContext.remove();
|
||||
}
|
||||
else {
|
||||
SecurityContextHolder.setContext(originalContext);
|
||||
SecurityContextHolder.setContext(context);
|
||||
}
|
||||
}
|
||||
catch (Throwable ex) {
|
||||
|
|
|
@ -134,28 +134,21 @@ public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArg
|
|||
|
||||
private Object resolvePrincipal(MethodParameter parameter, Object principal) {
|
||||
AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
|
||||
|
||||
String expressionToParse = authPrincipal.expression();
|
||||
if (StringUtils.hasLength(expressionToParse)) {
|
||||
StandardEvaluationContext context = new StandardEvaluationContext();
|
||||
context.setRootObject(principal);
|
||||
context.setVariable("this", principal);
|
||||
context.setBeanResolver(this.beanResolver);
|
||||
|
||||
Expression expression = this.parser.parseExpression(expressionToParse);
|
||||
principal = expression.getValue(context);
|
||||
}
|
||||
|
||||
if (isInvalidType(parameter, principal)) {
|
||||
|
||||
if (authPrincipal.errorOnInvalidType()) {
|
||||
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
|
||||
}
|
||||
else {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
return principal;
|
||||
}
|
||||
|
||||
|
|
|
@ -133,28 +133,21 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu
|
|||
|
||||
private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
|
||||
CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
|
||||
|
||||
String expressionToParse = contextAnno.expression();
|
||||
if (StringUtils.hasLength(expressionToParse)) {
|
||||
StandardEvaluationContext context = new StandardEvaluationContext();
|
||||
context.setRootObject(securityContext);
|
||||
context.setVariable("this", securityContext);
|
||||
context.setBeanResolver(this.beanResolver);
|
||||
|
||||
Expression expression = this.parser.parseExpression(expressionToParse);
|
||||
securityContext = expression.getValue(context);
|
||||
}
|
||||
|
||||
if (isInvalidType(parameter, securityContext)) {
|
||||
|
||||
if (contextAnno.errorOnInvalidType()) {
|
||||
throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType());
|
||||
}
|
||||
else {
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
return securityContext;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,13 @@ import org.springframework.util.Assert;
|
|||
*/
|
||||
public abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
|
||||
|
||||
protected final Log LOGGER = LogFactory.getLog(getClass());
|
||||
protected final Log logger = LogFactory.getLog(getClass());
|
||||
|
||||
/**
|
||||
* @deprecated since 5.4 in favor of {@link #logger}
|
||||
*/
|
||||
@Deprecated
|
||||
protected final Log LOGGER = this.logger;
|
||||
|
||||
private final List<MessageMatcher<T>> messageMatchers;
|
||||
|
||||
|
@ -41,9 +47,7 @@ public abstract class AbstractMessageMatcherComposite<T> implements MessageMatch
|
|||
*/
|
||||
AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) {
|
||||
Assert.notEmpty(messageMatchers, "messageMatchers must contain a value");
|
||||
if (messageMatchers.contains(null)) {
|
||||
throw new IllegalArgumentException("messageMatchers cannot contain null values");
|
||||
}
|
||||
Assert.isTrue(!messageMatchers.contains(null), "messageMatchers cannot contain null values");
|
||||
this.messageMatchers = messageMatchers;
|
||||
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.core.log.LogMessage;
|
||||
import org.springframework.messaging.Message;
|
||||
|
||||
/**
|
||||
|
@ -49,15 +50,13 @@ public final class AndMessageMatcher<T> extends AbstractMessageMatcherComposite<
|
|||
@Override
|
||||
public boolean matches(Message<? extends T> message) {
|
||||
for (MessageMatcher<T> matcher : getMessageMatchers()) {
|
||||
if (this.LOGGER.isDebugEnabled()) {
|
||||
this.LOGGER.debug("Trying to match using " + matcher);
|
||||
}
|
||||
this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
|
||||
if (!matcher.matches(message)) {
|
||||
this.LOGGER.debug("Did not match");
|
||||
this.logger.debug("Did not match");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
this.LOGGER.debug("All messageMatchers returned true");
|
||||
this.logger.debug("All messageMatchers returned true");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -26,17 +26,11 @@ import org.springframework.messaging.Message;
|
|||
*/
|
||||
public interface MessageMatcher<T> {
|
||||
|
||||
/**
|
||||
* Returns true if the {@link Message} matches, else false
|
||||
* @param message the {@link Message} to match on
|
||||
* @return true if the {@link Message} matches, else false
|
||||
*/
|
||||
boolean matches(Message<? extends T> message);
|
||||
|
||||
/**
|
||||
* Matches every {@link Message}
|
||||
*/
|
||||
MessageMatcher<Object> ANY_MESSAGE = new MessageMatcher<Object>() {
|
||||
|
||||
@Override
|
||||
public boolean matches(Message<?> message) {
|
||||
return true;
|
||||
|
@ -46,6 +40,14 @@ public interface MessageMatcher<T> {
|
|||
public String toString() {
|
||||
return "ANY_MESSAGE";
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns true if the {@link Message} matches, else false
|
||||
* @param message the {@link Message} to match on
|
||||
* @return true if the {@link Message} matches, else false
|
||||
*/
|
||||
boolean matches(Message<? extends T> message);
|
||||
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.core.log.LogMessage;
|
||||
import org.springframework.messaging.Message;
|
||||
|
||||
/**
|
||||
|
@ -49,15 +50,13 @@ public final class OrMessageMatcher<T> extends AbstractMessageMatcherComposite<T
|
|||
@Override
|
||||
public boolean matches(Message<? extends T> message) {
|
||||
for (MessageMatcher<T> matcher : getMessageMatchers()) {
|
||||
if (this.LOGGER.isDebugEnabled()) {
|
||||
this.LOGGER.debug("Trying to match using " + matcher);
|
||||
}
|
||||
this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
|
||||
if (matcher.matches(message)) {
|
||||
this.LOGGER.debug("matched");
|
||||
this.logger.debug("matched");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
this.LOGGER.debug("No matches found");
|
||||
this.logger.debug("No matches found");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -107,11 +107,8 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
|
|||
private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, PathMatcher pathMatcher) {
|
||||
Assert.notNull(pattern, "pattern cannot be null");
|
||||
Assert.notNull(pathMatcher, "pathMatcher cannot be null");
|
||||
if (!isTypeWithDestination(type)) {
|
||||
throw new IllegalArgumentException(
|
||||
"SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
|
||||
}
|
||||
|
||||
Assert.isTrue(isTypeWithDestination(type),
|
||||
() -> "SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
|
||||
this.matcher = pathMatcher;
|
||||
this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE;
|
||||
this.pattern = pattern;
|
||||
|
@ -122,7 +119,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
|
|||
if (!this.messageTypeMatcher.matches(message)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
|
||||
return destination != null && this.matcher.match(this.pattern, destination);
|
||||
}
|
||||
|
@ -144,10 +140,7 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
|
|||
}
|
||||
|
||||
private boolean isTypeWithDestination(SimpMessageType type) {
|
||||
if (type == null) {
|
||||
return true;
|
||||
}
|
||||
return SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
|
||||
return type == null || SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -49,7 +49,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
|
|||
public boolean matches(Message<?> message) {
|
||||
MessageHeaders headers = message.getHeaders();
|
||||
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
|
||||
|
||||
return this.typeToMatch == messageType;
|
||||
}
|
||||
|
||||
|
@ -63,7 +62,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
|
|||
}
|
||||
SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other;
|
||||
return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -46,23 +46,19 @@ public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter {
|
|||
if (!this.matcher.matches(message)) {
|
||||
return message;
|
||||
}
|
||||
|
||||
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
|
||||
CsrfToken expectedToken = (sessionAttributes != null)
|
||||
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
|
||||
|
||||
if (expectedToken == null) {
|
||||
throw new MissingCsrfTokenException(null);
|
||||
}
|
||||
|
||||
String actualTokenValue = SimpMessageHeaderAccessor.wrap(message)
|
||||
.getFirstNativeHeader(expectedToken.getHeaderName());
|
||||
|
||||
boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue);
|
||||
if (csrfCheckPassed) {
|
||||
return message;
|
||||
if (!csrfCheckPassed) {
|
||||
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
|
||||
}
|
||||
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
|
||||
return message;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue