SEC-2845: SecurityContextChannelInterceptor support anonymous

This commit is contained in:
Rob Winch 2015-02-18 09:40:25 -06:00
parent 6149f179c2
commit 36fe0d0357
5 changed files with 452 additions and 14 deletions

View File

@ -36,7 +36,7 @@ import org.springframework.security.core.context.SecurityContextHolder
* @author Rob Winch
*/
class MessagesConfigTests extends AbstractXmlConfigTests {
Authentication messageUser
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
def cleanup() {
SecurityContextHolder.clearContext()
@ -61,6 +61,21 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
clientInboundChannel.send(message('/permitAll'))
}
def 'anonymous authentication supported'() {
setup:
messages {
'message-interceptor'(pattern:'/permitAll',access:'permitAll')
'message-interceptor'(pattern:'/denyAll',access:'denyAll')
}
messageUser = null
when: 'message is sent to the permitAll endpoint with no user'
clientInboundChannel.send(message('/permitAll'))
then: 'access is granted'
noExceptionThrown()
}
def 'messages with no id automatically adds Authentication argument resolver'() {
setup:
def id = 'authenticationController'
@ -198,12 +213,13 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
}
def message(String destination) {
messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
headers.sessionId = '123'
headers.sessionAttributes = [:]
headers.destination = destination
headers.user = messageUser
if(messageUser != null) {
headers.user = messageUser
}
new GenericMessage<String>("hi",headers.messageHeaders)
}

View File

@ -0,0 +1,265 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.springframework.security.config.annotation.web.socket;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.MethodParameter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletConfig;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler;
import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession;
import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
import static org.fest.assertions.Assertions.assertThat;
import static org.junit.Assert.fail;
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
AnnotationConfigWebApplicationContext context;
TestingAuthenticationToken messageUser;
CsrfToken token;
String sessionAttr;
@Before
public void setup() {
token = new DefaultCsrfToken("header", "param", "token");
sessionAttr = "sessionAttr";
messageUser = new TestingAuthenticationToken("user","pass","ROLE_USER");
}
@After
public void cleanup() {
if(context != null) {
context.close();
}
}
@Test
public void simpleRegistryMappings() {
loadConfig(SockJsSecurityConfig.class);
clientInboundChannel().send(message("/permitAll"));
try {
clientInboundChannel().send(message("/denyAll"));
fail("Expected Exception");
} catch(MessageDeliveryException expected) {
assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
}
}
@Test
public void annonymousSupported() {
loadConfig(SockJsSecurityConfig.class);
messageUser = null;
clientInboundChannel().send(message("/permitAll"));
}
@Test
public void addsAuthenticationPrincipalResolver() throws InterruptedException {
loadConfig(SockJsSecurityConfig.class);
MessageChannel messageChannel = clientInboundChannel();
Message<String> message = message("/permitAll/authentication");
messageChannel.send(message);
assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal());
}
private MockHttpServletRequest sockjsHttpRequest(String mapping) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("GET");
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(sessionAttr,"sessionValue");
request.setAttribute(CsrfToken.class.getName(), token);
return request;
}
private Message<String> message(String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
return message(headers, destination);
}
private Message<String> message(SimpMessageHeaderAccessor headers, String destination) {
headers.setSessionId("123");
headers.setSessionAttributes(new HashMap<String, Object>());
if(destination != null) {
headers.setDestination(destination);
}
if(messageUser != null) {
headers.setUser(messageUser);
}
return new GenericMessage<String>("hi",headers.getMessageHeaders());
}
private MessageChannel clientInboundChannel() {
return context.getBean("clientInboundChannel", MessageChannel.class);
}
private void loadConfig(Class<?>... configs) {
context = new AnnotationConfigWebApplicationContext();
context.register(configs);
context.setServletConfig(new MockServletConfig());
context.refresh();
}
@Controller
static class MyController {
String authenticationPrincipal;
MyCustomArgument myCustomArgument;
@MessageMapping("/authentication")
public void authentication(@AuthenticationPrincipal String un) {
this.authenticationPrincipal = un;
}
@MessageMapping("/myCustom")
public void myCustom(MyCustomArgument myCustomArgument) {
this.myCustomArgument = myCustomArgument;
}
}
static class MyCustomArgument {
MyCustomArgument(String notDefaultConstr) {}
}
static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
return parameter.getParameterType().isAssignableFrom(MyCustomArgument.class);
}
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
return new MyCustomArgument("");
}
}
static class TestHandshakeHandler implements HandshakeHandler {
Map<String, Object> attributes;
public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
this.attributes = attributes;
if(wsHandler instanceof SockJsWebSocketHandler) {
// work around SPR-12716
SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
WebSocketServerSockJsSession session = (WebSocketServerSockJsSession) ReflectionTestUtils.getField(sockJs, "sockJsSession");
this.attributes = session.getAttributes();
}
return true;
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class SockJsSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/other")
.setHandshakeHandler(testHandshakeHandler())
.withSockJS()
.setInterceptors(new HttpSessionHandshakeInterceptor());
registry
.addEndpoint("/chat")
.setHandshakeHandler(testHandshakeHandler())
.withSockJS()
.setInterceptors(new HttpSessionHandshakeInterceptor());
}
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestMatchers("/permitAll/**").permitAll()
.anyMessage().denyAll();
}
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
}
@Bean
public MyController myController() {
return new MyController();
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
@Configuration
static class SyncExecutorConfig {
@Bean
public static SyncExecutorSubscribableChannelPostProcessor postProcessor() {
return new SyncExecutorSubscribableChannelPostProcessor();
}
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.springframework.security.config.annotation.web.socket;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
/**
* @author Rob Winch
*/
public class SyncExecutorSubscribableChannelPostProcessor implements BeanPostProcessor {
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
if(bean instanceof ExecutorSubscribableChannel) {
ExecutorSubscribableChannel original = (ExecutorSubscribableChannel) bean;
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
channel.setInterceptors(original.getInterceptors());
return channel;
}
return bean;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
return bean;
}
}

View File

@ -15,13 +15,17 @@
*/
package org.springframework.security.messaging.context;
import java.util.Stack;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.Assert;
@ -37,10 +41,12 @@ import org.springframework.util.Assert;
*/
public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter implements ExecutorChannelInterceptor {
private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
private static final ThreadLocal<SecurityContext> ORIGINAL_CONTEXT = new ThreadLocal<SecurityContext>();
private static final ThreadLocal<Stack<SecurityContext>> ORIGINAL_CONTEXT = new ThreadLocal<Stack<SecurityContext>>();
private final String authenticationHeaderName;
private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
/**
* Creates a new instance using the header of the name {@link SimpMessageHeaderAccessor#USER_HEADER}.
*/
@ -57,6 +63,21 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
Assert.notNull(authenticationHeaderName, "authenticationHeaderName cannot be null");
this.authenticationHeaderName = authenticationHeaderName;
}
/**
* Allows setting the Authentication used for anonymous authentication. Default is:
*
* <pre>
* new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
* </pre>
*
* @param authentication the Authentication used for anonymous authentication. Cannot be null.
*/
public void setAnonymousAuthentication(Authentication authentication) {
Assert.notNull(authentication, "authentication cannot be null");
this.anonymous = authentication;
}
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
setup(message);
@ -79,25 +100,42 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
private void setup(Message<?> message) {
SecurityContext currentContext = SecurityContextHolder.getContext();
ORIGINAL_CONTEXT.set(currentContext);
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
if(contextStack == null) {
contextStack = new Stack<SecurityContext>();
ORIGINAL_CONTEXT.set(contextStack);
}
contextStack.push(currentContext);
Object user = message.getHeaders().get(authenticationHeaderName);
if(!(user instanceof Authentication)) {
return;
Authentication authentication;
if((user instanceof Authentication)) {
authentication = (Authentication) user;
} else {
authentication = this.anonymous;
}
Authentication authentication = (Authentication) user;
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
}
private void cleanup() {
SecurityContext originalContext = ORIGINAL_CONTEXT.get();
ORIGINAL_CONTEXT.remove();
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
if(contextStack == null || contextStack.isEmpty()) {
SecurityContextHolder.clearContext();
ORIGINAL_CONTEXT.remove();
return;
}
SecurityContext originalContext = contextStack.pop();
try {
if(EMPTY_CONTEXT.equals(originalContext)) {
SecurityContextHolder.clearContext();
ORIGINAL_CONTEXT.remove();
} else {
SecurityContextHolder.setContext(originalContext);
}

View File

@ -10,8 +10,11 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
@ -35,10 +38,13 @@ public class SecurityContextChannelInterceptorTests {
SecurityContextChannelInterceptor interceptor;
AnonymousAuthenticationToken expectedAnonymous;
@Before
public void setup() {
authentication = new TestingAuthenticationToken("user","pass", "ROLE_USER");
messageBuilder = MessageBuilder.withPayload("payload");
expectedAnonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
interceptor = new SecurityContextChannelInterceptor();
}
@ -73,20 +79,45 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication);
}
@Test(expected = IllegalArgumentException.class)
public void setAnonymousAuthenticationNull() {
interceptor.setAnonymousAuthentication(null);
}
@Test
public void preSendUsesCustomAnonymous() throws Exception {
expectedAnonymous = new AnonymousAuthenticationToken("customKey", "customAnonymous", AuthorityUtils.createAuthorityList("ROLE_CUSTOM"));
interceptor.setAnonymousAuthentication(expectedAnonymous);
interceptor.preSend(messageBuilder.build(), channel);
assertAnonymous();
}
// SEC-2845
@Test
public void preSendUserNotAuthentication() throws Exception {
messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, principal);
interceptor.preSend(messageBuilder.build(), channel);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertAnonymous();
}
// SEC-2845
@Test
public void preSendUserNotSet() throws Exception {
interceptor.preSend(messageBuilder.build(), channel);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertAnonymous();
}
// SEC-2845
@Test
public void preSendUserNotSetCustomAnonymous() throws Exception {
interceptor.preSend(messageBuilder.build(), channel);
assertAnonymous();
}
@Test
@ -114,20 +145,22 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication);
}
// SEC-2845
@Test
public void beforeHandleUserNotAuthentication() throws Exception {
messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, principal);
interceptor.beforeHandle(messageBuilder.build(), channel, handler);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertAnonymous();
}
// SEC-2845
@Test
public void beforeHandleUserNotSet() throws Exception {
interceptor.beforeHandle(messageBuilder.build(), channel, handler);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertAnonymous();
}
@ -147,6 +180,7 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
}
// SEC-2829
@Test
public void restoresOriginalContext() throws Exception {
TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER");
@ -161,4 +195,47 @@ public class SecurityContextChannelInterceptorTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original);
}
/**
* If a user sends a message when processing another message
*
* @throws Exception
*/
@Test
public void restoresOriginalContextNestedThreeDeep() throws Exception {
AnonymousAuthenticationToken anonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_USER"));
TestingAuthenticationToken origional = new TestingAuthenticationToken("original", "origional", "ROLE_USER");
SecurityContextHolder.getContext().setAuthentication(origional);
messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication);
interceptor.beforeHandle(messageBuilder.build(), channel, handler);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication);
// start send message
messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, null);
interceptor.beforeHandle(messageBuilder.build(), channel, handler);
assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo(anonymous.getName());
interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication);
// end send message
interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(origional);
}
private void assertAnonymous() {
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
assertThat(currentAuthentication).isInstanceOf(AnonymousAuthenticationToken.class);
AnonymousAuthenticationToken anonymous = (AnonymousAuthenticationToken) currentAuthentication;
assertThat(anonymous.getName()).isEqualTo(expectedAnonymous.getName());
assertThat(anonymous.getAuthorities()).containsOnly(expectedAnonymous.getAuthorities().toArray());
assertThat(anonymous.getKeyHash()).isEqualTo(expectedAnonymous.getKeyHash());
}
}