SEC-2830: Provide Same Origin support for SockJS

This commit is contained in:
Rob Winch 2015-02-18 09:14:10 -06:00
parent a27c33754c
commit 6a8475adbb
13 changed files with 872 additions and 6 deletions

View File

@ -15,9 +15,14 @@
*/ */
package org.springframework.security.config.annotation.web.configurers; package org.springframework.security.config.annotation.web.configurers;
import java.util.ArrayList;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.access.AccessDeniedHandler;
@ -31,6 +36,9 @@ import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler; import org.springframework.security.web.session.InvalidSessionAccessDeniedHandler;
import org.springframework.security.web.session.InvalidSessionStrategy; import org.springframework.security.web.session.InvalidSessionStrategy;
import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -66,7 +74,8 @@ import org.springframework.util.Assert;
*/ */
public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends AbstractHttpConfigurer<CsrfConfigurer<H>,H> { public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends AbstractHttpConfigurer<CsrfConfigurer<H>,H> {
private CsrfTokenRepository csrfTokenRepository = new HttpSessionCsrfTokenRepository(); private CsrfTokenRepository csrfTokenRepository = new HttpSessionCsrfTokenRepository();
private RequestMatcher requireCsrfProtectionMatcher; private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER;
private List<RequestMatcher> ignoredCsrfProtectionMatchers = new ArrayList<RequestMatcher>();
/** /**
* Creates a new instance * Creates a new instance
@ -102,10 +111,38 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends Abst
return this; return this;
} }
/**
* <p>
* Allows specifying {@link HttpServletRequest} that should not use CSRF Protection even if they match the {@link #requireCsrfProtectionMatcher(RequestMatcher)}.
* </p>
*
* <p>
* The following will ensure CSRF protection ignores:
* </p>
* <ul>
* <li>Any GET, HEAD, TRACE, OPTIONS (this is the default)</li>
* <li>We also explicitly state to ignore any request that starts with "/sockjs/"</li>
* </ul>
*
* <pre>
* http
* .csrf()
* .ignoringAntMatchers("/sockjs/**")
* .and()
* ...
* </pre>
*
* @since 4.0
*/
public CsrfConfigurer<H> ignoringAntMatchers(String... antPatterns) {
return new IgnoreCsrfProtectionRegistry().antMatchers(antPatterns).and();
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public void configure(H http) throws Exception { public void configure(H http) throws Exception {
CsrfFilter filter = new CsrfFilter(csrfTokenRepository); CsrfFilter filter = new CsrfFilter(csrfTokenRepository);
RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
if(requireCsrfProtectionMatcher != null) { if(requireCsrfProtectionMatcher != null) {
filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
} }
@ -125,6 +162,18 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends Abst
http.addFilter(filter); http.addFilter(filter);
} }
/**
* Gets the final {@link RequestMatcher} to use by combining the {@link #requireCsrfProtectionMatcher(RequestMatcher)} and any {@link #ignore()}.
*
* @return the {@link RequestMatcher} to use
*/
private RequestMatcher getRequireCsrfProtectionMatcher() {
if(ignoredCsrfProtectionMatchers.isEmpty()) {
return requireCsrfProtectionMatcher;
}
return new AndRequestMatcher(requireCsrfProtectionMatcher, new NegatedRequestMatcher(new OrRequestMatcher(ignoredCsrfProtectionMatchers)));
}
/** /**
* Gets the default {@link AccessDeniedHandler} from the * Gets the default {@link AccessDeniedHandler} from the
* {@link ExceptionHandlingConfigurer#getAccessDeniedHandler()} or create a * {@link ExceptionHandlingConfigurer#getAccessDeniedHandler()} or create a
@ -190,4 +239,25 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>> extends Abst
handlers.put(MissingCsrfTokenException.class, invalidSessionDeniedHandler); handlers.put(MissingCsrfTokenException.class, invalidSessionDeniedHandler);
return new DelegatingAccessDeniedHandler(handlers, defaultAccessDeniedHandler); return new DelegatingAccessDeniedHandler(handlers, defaultAccessDeniedHandler);
} }
/**
* Allows registering {@link RequestMatcher} instances that should be
* ignored (even if the {@link HttpServletRequest} matches the
* {@link CsrfConfigurer#requireCsrfProtectionMatcher(RequestMatcher)}.
*
* @author Rob Winch
* @since 4.0
*/
private class IgnoreCsrfProtectionRegistry extends AbstractRequestMatcherRegistry<IgnoreCsrfProtectionRegistry>{
public CsrfConfigurer<H> and() {
return CsrfConfigurer.this;
}
protected IgnoreCsrfProtectionRegistry chainRequestMatchers(
List<RequestMatcher> requestMatchers) {
ignoredCsrfProtectionMatchers.addAll(requestMatchers);
return this;
}
}
} }

View File

@ -15,6 +15,9 @@
*/ */
package org.springframework.security.config.annotation.web.socket; package org.springframework.security.config.annotation.web.socket;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
@ -28,11 +31,20 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn
import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource; import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* Allows configuring WebSocket Authorization. * Allows configuring WebSocket Authorization.
@ -58,9 +70,12 @@ import java.util.List;
* @author Rob Winch * @author Rob Winch
*/ */
@Order(Ordered.HIGHEST_PRECEDENCE + 100) @Order(Ordered.HIGHEST_PRECEDENCE + 100)
public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer { public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer
implements SmartInitializingSingleton {
private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry();
private ApplicationContext context;
public void registerStompEndpoints(StompEndpointRegistry registry) {} public void registerStompEndpoints(StompEndpointRegistry registry) {}
@Override @Override
@ -69,16 +84,34 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A
argumentResolvers.add(new AuthenticationPrincipalArgumentResolver()); argumentResolvers.add(new AuthenticationPrincipalArgumentResolver());
} }
@Override @Override
public final void configureClientInboundChannel(ChannelRegistration registration) { public final void configureClientInboundChannel(ChannelRegistration registration) {
ChannelSecurityInterceptor inboundChannelSecurity = inboundChannelSecurity(); ChannelSecurityInterceptor inboundChannelSecurity = inboundChannelSecurity();
registration.setInterceptors(securityContextChannelInterceptor());
if(sameOriginEnforced()) {
registration.setInterceptors(csrfChannelInterceptor());
}
if(inboundRegistry.containsMapping()) { if(inboundRegistry.containsMapping()) {
registration.setInterceptors(securityContextChannelInterceptor(),inboundChannelSecurity); registration.setInterceptors(inboundChannelSecurity);
} }
customizeClientInboundChannel(registration); customizeClientInboundChannel(registration);
} }
/**
* <p>
* Determines if a CSRF token is required for connecting. This protects against remote sites from connecting to the
* application and being able to read/write data over the connection. The default is true.
* </p>
* <p>
* Subclasses can override this method to disable CSRF protection
* </p>
*
* @return true if a CSRF is required for connecting, else false
*/
protected boolean sameOriginEnforced() {
return true;
}
/** /**
* Allows subclasses to customize the configuration of the {@link ChannelRegistration}. * Allows subclasses to customize the configuration of the {@link ChannelRegistration}.
* *
@ -87,6 +120,11 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A
protected void customizeClientInboundChannel(ChannelRegistration registration) { protected void customizeClientInboundChannel(ChannelRegistration registration) {
} }
@Bean
public CsrfChannelInterceptor csrfChannelInterceptor() {
return new CsrfChannelInterceptor();
}
@Bean @Bean
public ChannelSecurityInterceptor inboundChannelSecurity() { public ChannelSecurityInterceptor inboundChannelSecurity() {
ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(inboundMessageSecurityMetadataSource()); ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(inboundMessageSecurityMetadataSource());
@ -125,4 +163,47 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A
return super.containsMapping(); return super.containsMapping();
} }
} }
@Autowired
public void setApplicationContext(ApplicationContext context) {
this.context = context;
}
public void afterSingletonsInstantiated() {
if(!sameOriginEnforced()) {
return;
}
String beanName = "stompWebSocketHandlerMapping";
SimpleUrlHandlerMapping mapping = context.getBean(beanName, SimpleUrlHandlerMapping.class);
Map<String, Object> mappings = mapping.getHandlerMap();
for(Object object : mappings.values()) {
if(object instanceof SockJsHttpRequestHandler) {
SockJsHttpRequestHandler sockjsHandler = (SockJsHttpRequestHandler) object;
SockJsService sockJsService = sockjsHandler.getSockJsService();
if(!(sockJsService instanceof TransportHandlingSockJsService)) {
throw new IllegalStateException("sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService);
}
TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService) sockJsService;
List<HandshakeInterceptor> handshakeInterceptors = transportHandlingSockJsService.getHandshakeInterceptors();
List<HandshakeInterceptor> interceptorsToSet = new ArrayList<HandshakeInterceptor>(handshakeInterceptors.size() + 1);
interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
interceptorsToSet.addAll(handshakeInterceptors);
transportHandlingSockJsService.setHandshakeInterceptors(interceptorsToSet);
}
else if(object instanceof WebSocketHttpRequestHandler) {
WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) object;
List<HandshakeInterceptor> handshakeInterceptors = handler.getHandshakeInterceptors();
List<HandshakeInterceptor> interceptorsToSet = new ArrayList<HandshakeInterceptor>(handshakeInterceptors.size() + 1);
interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
interceptorsToSet.addAll(handshakeInterceptors);
handler.setHandshakeInterceptors(interceptorsToSet);
} else {
throw new IllegalStateException("Bean " + beanName + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object);
}
}
}
} }

View File

@ -33,6 +33,8 @@ import org.springframework.security.messaging.access.intercept.ChannelSecurityIn
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
import org.w3c.dom.Element; import org.w3c.dom.Element;
@ -152,7 +154,8 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
String[] beanNames = registry.getBeanDefinitionNames(); String[] beanNames = registry.getBeanDefinitionNames();
for(String beanName : beanNames) { for(String beanName : beanNames) {
BeanDefinition bd = registry.getBeanDefinition(beanName); BeanDefinition bd = registry.getBeanDefinition(beanName);
if(bd.getBeanClassName().equals(SimpAnnotationMethodMessageHandler.class.getName())) { String beanClassName = bd.getBeanClassName();
if(beanClassName.equals(SimpAnnotationMethodMessageHandler.class.getName())) {
PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP); PropertyValue current = bd.getPropertyValues().getPropertyValue(CUSTOM_ARG_RESOLVERS_PROP);
ManagedList<Object> argResolvers = new ManagedList<Object>(); ManagedList<Object> argResolvers = new ManagedList<Object>();
if(current != null) { if(current != null) {
@ -161,6 +164,13 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
argResolvers.add(new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class)); argResolvers.add(new RootBeanDefinition(AuthenticationPrincipalArgumentResolver.class));
bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers); bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers);
} }
else if(beanClassName.equals("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler")) {
addCsrfTokenHandshakeInterceptor(bd);
} else if(beanClassName.equals("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService")) {
addCsrfTokenHandshakeInterceptor(bd);
} else if(beanClassName.equals("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService")) {
addCsrfTokenHandshakeInterceptor(bd);
}
} }
if(!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) { if(!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) {
@ -168,6 +178,7 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
} }
ManagedList<Object> interceptors = new ManagedList(); ManagedList<Object> interceptors = new ManagedList();
interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class)); interceptors.add(new RootBeanDefinition(SecurityContextChannelInterceptor.class));
interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class));
interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId)); interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId));
BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); BeanDefinition inboundChannel = registry.getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID);
@ -180,6 +191,14 @@ public final class MessageSecurityBeanDefinitionParser implements BeanDefinition
inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors); inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors);
} }
private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) {
String interceptorPropertyName = "handshakeInterceptors";
ManagedList<? super Object> interceptors = new ManagedList<Object>();
interceptors.add(new RootBeanDefinition(CsrfTokenHandshakeInterceptor.class));
interceptors.addAll((ManagedList<Object>)bd.getPropertyValues().get(interceptorPropertyName));
bd.getPropertyValues().add(interceptorPropertyName, interceptors);
}
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
} }

View File

@ -12,12 +12,24 @@ import org.springframework.http.server.ServerHttpRequest
import org.springframework.http.server.ServerHttpResponse import org.springframework.http.server.ServerHttpResponse
import org.springframework.messaging.handler.annotation.MessageMapping import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver
import org.springframework.messaging.simp.SimpMessageType
import org.springframework.mock.web.MockHttpServletRequest
import org.springframework.mock.web.MockHttpServletResponse
import org.springframework.security.core.Authentication import org.springframework.security.core.Authentication
import org.springframework.security.core.annotation.AuthenticationPrincipal import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.web.csrf.CsrfToken
import org.springframework.security.web.csrf.DefaultCsrfToken
import org.springframework.security.web.csrf.MissingCsrfTokenException
import org.springframework.stereotype.Controller import org.springframework.stereotype.Controller
import org.springframework.web.servlet.HandlerMapping
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping
import org.springframework.web.socket.WebSocketHandler import org.springframework.web.socket.WebSocketHandler
import org.springframework.web.socket.server.HandshakeFailureException import org.springframework.web.socket.server.HandshakeFailureException
import org.springframework.web.socket.server.HandshakeHandler import org.springframework.web.socket.server.HandshakeHandler
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler
import static org.mockito.Mockito.* import static org.mockito.Mockito.*
@ -37,6 +49,7 @@ import org.springframework.security.core.context.SecurityContextHolder
*/ */
class MessagesConfigTests extends AbstractXmlConfigTests { class MessagesConfigTests extends AbstractXmlConfigTests {
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
boolean useSockJS = false
def cleanup() { def cleanup() {
SecurityContextHolder.clearContext() SecurityContextHolder.clearContext()
@ -93,6 +106,89 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
controller.authenticationPrincipal == messageUser.name controller.authenticationPrincipal == messageUser.name
} }
def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor'() {
setup:
def id = 'authenticationController'
bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor)
messages {
'message-interceptor'(pattern:'/**',access:'permitAll')
}
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication')
WebSocketHttpRequestHandler handler = appContext.getBean(WebSocketHttpRequestHandler)
MockHttpServletRequest request = new MockHttpServletRequest()
String sessionAttr = "sessionAttr"
request.getSession().setAttribute(sessionAttr,"sessionValue")
CsrfToken token = new DefaultCsrfToken("header", "param", "token")
request.setAttribute(CsrfToken.name, token)
when:
handler.handleRequest(request , new MockHttpServletResponse())
TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
then: 'CsrfToken is populated'
handshakeHandler.attributes.get(CsrfToken.name) == token
and: 'Explicitly listed HandshakeInterceptor are not overridden'
handshakeHandler.attributes.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
}
def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor with SockJS'() {
setup:
useSockJS = true
def id = 'authenticationController'
bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor)
messages {
'message-interceptor'(pattern:'/**',access:'permitAll')
}
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication')
SockJsHttpRequestHandler handler = appContext.getBean(SockJsHttpRequestHandler)
MockHttpServletRequest request = new MockHttpServletRequest()
String sessionAttr = "sessionAttr"
request.getSession().setAttribute(sessionAttr,"sessionValue")
CsrfToken token = new DefaultCsrfToken("header", "param", "token")
request.setAttribute(CsrfToken.name, token)
request.setMethod("GET")
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket")
when:
handler.handleRequest(request , new MockHttpServletResponse())
TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
then: 'CsrfToken is populated'
handshakeHandler.attributes?.get(CsrfToken.name) == token
and: 'Explicitly listed HandshakeInterceptor are not overridden'
handshakeHandler.attributes?.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
}
def 'messages of type CONNECT require valid CsrfToken'() {
setup:
def id = 'authenticationController'
bean(id,MyController)
bean('inPostProcessor',InboundExecutorPostProcessor)
messages {
'message-interceptor'(pattern:'/**',access:'permitAll')
}
when: 'message of type CONNECTION is sent without CsrfTOken'
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
Message<?> message = message(headers,'/authentication')
clientInboundChannel.send(message)
then: 'CSRF Protection blocks the Message'
MessageDeliveryException expected = thrown()
expected.cause instanceof MissingCsrfTokenException
}
def 'messages with no id does not override customArgumentResolvers'() { def 'messages with no id does not override customArgumentResolvers'() {
setup: setup:
def id = 'authenticationController' def id = 'authenticationController'
@ -201,6 +297,12 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
'websocket:transport' {} 'websocket:transport' {}
'websocket:stomp-endpoint'(path:'/app') { 'websocket:stomp-endpoint'(path:'/app') {
'websocket:handshake-handler'(ref:'testHandler') {} 'websocket:handshake-handler'(ref:'testHandler') {}
'websocket:handshake-interceptors' {
'b:bean'('class':HttpSessionHandshakeInterceptor.name) {}
}
if(useSockJS) {
'websocket:sockjs' {}
}
} }
'websocket:simple-broker'(prefix:"/queue, /topic"){} 'websocket:simple-broker'(prefix:"/queue, /topic"){}
} }
@ -214,6 +316,11 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
def message(String destination) { def message(String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create() SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
message(headers, destination)
}
def message(SimpMessageHeaderAccessor headers, String destination) {
messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
headers.sessionId = '123' headers.sessionId = '123'
headers.sessionAttributes = [:] headers.sessionAttributes = [:]
headers.destination = destination headers.destination = destination
@ -257,8 +364,15 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
} }
static class TestHandshakeHandler implements HandshakeHandler { static class TestHandshakeHandler implements HandshakeHandler {
@Override Map<String, Object> attributes;
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException { boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
this.attributes = attributes
if(wsHandler instanceof SockJsWebSocketHandler) {
// work around SPR-12716
SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
this.attributes = sockJs.sockJsSession.attributes
}
true true
} }
} }

View File

@ -122,6 +122,120 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal()); assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal());
} }
@Test
public void addsAuthenticationPrincipalResolverWhenNoAuthorization() throws InterruptedException {
loadConfig(NoInboundSecurityConfig.class);
MessageChannel messageChannel = clientInboundChannel();
Message<String> message = message("/permitAll/authentication");
messageChannel.send(message);
assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal());
}
@Test
public void addsCsrfProtectionWhenNoAuthorization() throws InterruptedException {
loadConfig(NoInboundSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MessageChannel messageChannel = clientInboundChannel();
try {
messageChannel.send(message);
fail("Expected Exception");
} catch(MessageDeliveryException success) {
assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class);
}
}
@Test
public void csrfProtectionForConnect() throws InterruptedException {
loadConfig(SockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MessageChannel messageChannel = clientInboundChannel();
try {
messageChannel.send(message);
fail("Expected Exception");
} catch(MessageDeliveryException success) {
assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class);
}
}
@Test
public void csrfProtectionDisabledForConnect() throws InterruptedException {
loadConfig(CsrfDisabledSockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/permitAll/connect");
MessageChannel messageChannel = clientInboundChannel();
messageChannel.send(message);
}
@Test
public void messagesConnectUseCsrfTokenHandshakeInterceptor() throws Exception {
loadConfig(SockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MockHttpServletRequest request = sockjsHttpRequest("/chat");
HttpRequestHandler handler = handler(request);
handler.handleRequest(request, new MockHttpServletResponse());
assertHandshake(request);
}
@Test
public void messagesConnectUseCsrfTokenHandshakeInterceptorMultipleMappings() throws Exception {
loadConfig(SockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MockHttpServletRequest request = sockjsHttpRequest("/other");
HttpRequestHandler handler = handler(request);
handler.handleRequest(request, new MockHttpServletResponse());
assertHandshake(request);
}
@Test
public void messagesConnectWebSocketUseCsrfTokenHandshakeInterceptor() throws Exception {
loadConfig(WebSocketSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MockHttpServletRequest request = websocketHttpRequest("/websocket");
HttpRequestHandler handler = handler(request);
handler.handleRequest(request, new MockHttpServletResponse());
assertHandshake(request);
}
private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(token);
assertThat(handshakeHandler.attributes.get(sessionAttr)).isEqualTo(request.getSession().getAttribute(sessionAttr));
}
private HttpRequestHandler handler(HttpServletRequest request) throws Exception {
HandlerMapping handlerMapping = context.getBean(HandlerMapping.class);
return (HttpRequestHandler) handlerMapping.getHandler(request).getHandler();
}
private MockHttpServletRequest websocketHttpRequest(String mapping) {
MockHttpServletRequest request = sockjsHttpRequest(mapping);
request.setRequestURI(mapping);
return request;
}
private MockHttpServletRequest sockjsHttpRequest(String mapping) { private MockHttpServletRequest sockjsHttpRequest(String mapping) {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("GET"); request.setMethod("GET");
@ -255,6 +369,75 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
} }
} }
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class NoInboundSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/other")
.withSockJS()
.setInterceptors(new HttpSessionHandshakeInterceptor());
registry
.addEndpoint("/chat")
.withSockJS()
.setInterceptors(new HttpSessionHandshakeInterceptor());
}
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
}
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
}
@Bean
public MyController myController() {
return new MyController();
}
}
@Configuration
static class CsrfDisabledSockJsSecurityConfig extends SockJsSecurityConfig {
@Override
protected boolean sameOriginEnforced() {
return false;
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/websocket")
.setHandshakeHandler(testHandshakeHandler())
.addInterceptors(new HttpSessionHandshakeInterceptor());
}
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestMatchers("/permitAll/**").permitAll()
.anyMessage().denyAll();
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
@Configuration @Configuration
static class SyncExecutorConfig { static class SyncExecutorConfig {
@Bean @Bean

View File

@ -7803,6 +7803,8 @@ The messages attribute has two different modes. If the <<nsa-messages-id>> is no
* Ensure that any SimpAnnotationMethodMessageHandler has the AuthenticationPrincipalArgumentResolver registered as a custom argument resolver. This allows the use of `@AuthenticationPrincipal` to resolve the principal of the current `Authentication` * Ensure that any SimpAnnotationMethodMessageHandler has the AuthenticationPrincipalArgumentResolver registered as a custom argument resolver. This allows the use of `@AuthenticationPrincipal` to resolve the principal of the current `Authentication`
* Ensures that the SecurityContextChannelInterceptor is automatically registered for the clientInboundChannel. This populates the SecurityContextHolder with the user that is found in the Message * Ensures that the SecurityContextChannelInterceptor is automatically registered for the clientInboundChannel. This populates the SecurityContextHolder with the user that is found in the Message
* Ensures that a ChannelSecurityInterceptor is registered with the clientInboundChannel. This allows authorization rules to be specified for a message. * Ensures that a ChannelSecurityInterceptor is registered with the clientInboundChannel. This allows authorization rules to be specified for a message.
* Ensures that a CsrfChannelInterceptor is registered with the clientInboundChannel. This ensures that only requests from the original domain are enabled.
* Ensures that a CsrfTokenHandshakeInterceptor is registered with WebSocketHttpRequestHandler, TransportHandlingSockJsService, or DefaultSockJsService. This ensures that the expected CsrfToken from the HttpServletRequest is copied into the WebSocket Session attributes.
If additional control is necessary, the id can be specified and a ChannelSecurityInterceptor will be assigned to the specified id. All the wiring with Spring's messaging infrastructure can then be done manually. This is more cumbersome, but provides greater control over the configuration. If additional control is necessary, the id can be specified and a ChannelSecurityInterceptor will be assigned to the specified id. All the wiring with Spring's messaging infrastructure can then be done manually. This is more cumbersome, but provides greater control over the configuration.

View File

@ -9,6 +9,10 @@ dependencies {
"org.springframework:spring-expression:$springVersion", "org.springframework:spring-expression:$springVersion",
"org.springframework:spring-messaging:$springVersion" "org.springframework:spring-messaging:$springVersion"
optional project(':spring-security-web'),
"org.springframework:spring-websocket:$springVersion",
"javax.servlet:javax.servlet-api:$servletApiVersion"
testCompile project(':spring-security-core').sourceSets.test.output, testCompile project(':spring-security-core').sourceSets.test.output,
"commons-codec:commons-codec:$commonsCodecVersion", "commons-codec:commons-codec:$commonsCodecVersion",
"org.slf4j:jcl-over-slf4j:$slf4jVersion", "org.slf4j:jcl-over-slf4j:$slf4jVersion",

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.util.Map;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.security.messaging.util.matcher.MessageMatcher;
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
/**
* {@link ChannelInterceptorAdapter} that validates that a valid CSRF is included in the header of any
* {@link SimpMessageType#CONNECT} message. The expected {@link CsrfToken} is populated by CsrfTokenHandshakeInterceptor.
*
* @author Rob Winch
* @since 4.0
*/
public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter {
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
if(!matcher.matches(message)) {
return message;
}
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
CsrfToken expectedToken = sessionAttributes == null ? null : (CsrfToken) sessionAttributes.get(CsrfToken.class.getName());
if(expectedToken == null) {
throw new MissingCsrfTokenException(null);
}
String actualTokenValue = SimpMessageHeaderAccessor.wrap(message).getFirstNativeHeader(expectedToken.getHeaderName());
boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue);
if(csrfCheckPassed) {
return message;
}
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.socket.server;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket attributes. This is used as the
* expected CsrfToken when validating connection requests to ensure only the same origin connects.
*
* @author Rob Winch
* @since 4.0
*/
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
public boolean beforeHandshake(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler wsHandler,
Map<String, Object> attributes) throws Exception {
HttpServletRequest httpRequest = ((ServletServerHttpRequest)request).getServletRequest();
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
if(token == null) {
return true;
}
attributes.put(CsrfToken.class.getName(), token);
return true;
}
public void afterHandshake(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler wsHandler,
Exception exception) {
}
}

View File

@ -0,0 +1,154 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.messaging.web.csrf;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
@RunWith(MockitoJUnitRunner.class)
public class CsrfChannelInterceptorTests {
@Mock
MessageChannel channel;
SimpMessageHeaderAccessor messageHeaders;
CsrfToken token;
CsrfChannelInterceptor interceptor;
@Before
public void setup() {
token = new DefaultCsrfToken("header", "param", "token");
interceptor = new CsrfChannelInterceptor();
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken());
messageHeaders.setSessionAttributes(new HashMap<String,Object>());
messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), token);
}
@Test
public void preSendValidToken() {
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresConnectAck() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresDisconnect() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresDisconnectAck() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresHeartbeat() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresMessage() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresOther() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresSubscribe() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE);
interceptor.preSend(message(), channel);
}
@Test
public void preSendIgnoresUnsubscribe() {
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
interceptor.preSend(message(), channel);
}
@Test(expected = InvalidCsrfTokenException.class)
public void preSendNoToken() {
messageHeaders.removeNativeHeader(token.getHeaderName());
interceptor.preSend(message(), channel);
}
@Test(expected = InvalidCsrfTokenException.class)
public void preSendInvalidToken() {
messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken() + "invalid");
interceptor.preSend(message(), channel);
}
@Test(expected = MissingCsrfTokenException.class)
public void preSendMissingToken() {
messageHeaders.getSessionAttributes().clear();
interceptor.preSend(message(), channel);
}
@Test(expected = MissingCsrfTokenException.class)
public void preSendMissingTokenNullSessionAttributes() {
messageHeaders.setSessionAttributes(null);
interceptor.preSend(message(), channel);
}
private Message<String> message() {
Map<String, Object> headersToCopy = messageHeaders.toMap();
return MessageBuilder
.withPayload("hi")
.copyHeaders(headersToCopy)
.build();
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.springframework.security.messaging.web.socket.server;
import org.junit.Test;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.web.socket.WebSocketHandler;
import java.util.HashMap;
import java.util.Map;
import static org.fest.assertions.Assertions.assertThat;
/**
*
* @author Rob Winch
*/
@RunWith(MockitoJUnitRunner.class)
public class CsrfTokenHandshakeInterceptorTests {
@Mock
WebSocketHandler wsHandler;
@Mock
ServerHttpResponse response;
Map<String, Object> attributes;
ServerHttpRequest request;
MockHttpServletRequest httpRequest;
CsrfTokenHandshakeInterceptor interceptor;
@Before
public void setup() {
httpRequest = new MockHttpServletRequest();
attributes = new HashMap<String,Object>();
request = new ServletServerHttpRequest(httpRequest);
interceptor = new CsrfTokenHandshakeInterceptor();
}
@Test
public void beforeHandshakeNoAttribute() throws Exception {
interceptor.beforeHandshake(request, response, wsHandler, attributes);
assertThat(attributes).isEmpty();
}
@Test
public void beforeHandshake() throws Exception {
CsrfToken token = new DefaultCsrfToken("header", "param", "token");
httpRequest.setAttribute(CsrfToken.class.getName(), token);
interceptor.beforeHandshake(request, response, wsHandler, attributes);
assertThat(attributes.keySet()).containsOnly(CsrfToken.class.getName());
assertThat(attributes.values()).containsOnly(token);
}
}

View File

@ -55,6 +55,13 @@ import org.springframework.web.filter.OncePerRequestFilter;
* @since 3.2 * @since 3.2
*/ */
public final class CsrfFilter extends OncePerRequestFilter { public final class CsrfFilter extends OncePerRequestFilter {
/**
* The default {@link RequestMatcher} that indicates if CSRF protection is
* required or not. The default is to ignore GET, HEAD, TRACE, OPTIONS and
* process all other requests.
*/
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
private final Log logger = LogFactory.getLog(getClass()); private final Log logger = LogFactory.getLog(getClass());
private final CsrfTokenRepository tokenRepository; private final CsrfTokenRepository tokenRepository;
private RequestMatcher requireCsrfProtectionMatcher = new DefaultRequiresCsrfMatcher(); private RequestMatcher requireCsrfProtectionMatcher = new DefaultRequiresCsrfMatcher();

View File

@ -0,0 +1,33 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.csrf;
import org.junit.Test;
/**
*
* @author Rob Winch
*
*/
public class MissingCsrfTokenExceptionTests {
// CsrfChannelInterceptor requires this to work
@Test
public void nullExpectedTokenDoesNotFail() {
new MissingCsrfTokenException(null);
}
}