SEC-1167: Introduce more flexible SavedRequest handling. Separated the concept of SavedRequest from SecurityContextHolderAwareFilter since the two are orthogonal requirements. This no longer takes a wrapper class property or uses reflection. SavedRequest functionality is accessed through the RequestCache interface, with the default implementation being HttpSessionRequestCache. A separate filter RequestCacheAwareFilter is now responsible for reconstituting the SavedRequest if it matches the current request. The functionality for matching and returning the wrapper is contained in the RequestCache method though.

This commit is contained in:
Luke Taylor 2009-07-20 22:34:40 +00:00
parent efd1dbf54a
commit f404bb3d74
18 changed files with 452 additions and 394 deletions

View File

@ -33,6 +33,7 @@ abstract class FilterChainOrder {
public static final int LOGIN_PAGE_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++; public static final int LOGIN_PAGE_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;
public static final int DIGEST_PROCESSING_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++; public static final int DIGEST_PROCESSING_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;
public static final int BASIC_PROCESSING_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++; public static final int BASIC_PROCESSING_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;
public static final int REQUEST_CACHE_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;
public static final int SERVLET_API_SUPPORT_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++; public static final int SERVLET_API_SUPPORT_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;
public static final int REMEMBER_ME_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++; public static final int REMEMBER_ME_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;
public static final int ANONYMOUS_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++; public static final int ANONYMOUS_FILTER = FILTER_CHAIN_FIRST + INTERVAL * i++;

View File

@ -3,6 +3,7 @@ package org.springframework.security.config.http;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.xml.ParserContext; import org.springframework.beans.factory.xml.ParserContext;
@ -36,16 +37,18 @@ public class FormLoginBeanDefinitionParser {
private static final String ATT_SUCCESS_HANDLER_REF = "authentication-success-handler-ref"; private static final String ATT_SUCCESS_HANDLER_REF = "authentication-success-handler-ref";
private static final String ATT_FAILURE_HANDLER_REF = "authentication-failure-handler-ref"; private static final String ATT_FAILURE_HANDLER_REF = "authentication-failure-handler-ref";
private String defaultLoginProcessingUrl; private final String defaultLoginProcessingUrl;
private String filterClassName; private final String filterClassName;
private final BeanReference requestCache;
private RootBeanDefinition filterBean; private RootBeanDefinition filterBean;
private RootBeanDefinition entryPointBean; private RootBeanDefinition entryPointBean;
private String loginPage; private String loginPage;
FormLoginBeanDefinitionParser(String defaultLoginProcessingUrl, String filterClassName) { FormLoginBeanDefinitionParser(String defaultLoginProcessingUrl, String filterClassName, BeanReference requestCache) {
this.defaultLoginProcessingUrl = defaultLoginProcessingUrl; this.defaultLoginProcessingUrl = defaultLoginProcessingUrl;
this.filterClassName = filterClassName; this.filterClassName = filterClassName;
this.requestCache = requestCache;
} }
public BeanDefinition parse(Element elt, ParserContext pc, RootBeanDefinition sfpf) { public BeanDefinition parse(Element elt, ParserContext pc, RootBeanDefinition sfpf) {
@ -114,6 +117,7 @@ public class FormLoginBeanDefinitionParser {
if ("true".equals(alwaysUseDefault)) { if ("true".equals(alwaysUseDefault)) {
successHandler.addPropertyValue("alwaysUseDefaultTargetUrl", Boolean.TRUE); successHandler.addPropertyValue("alwaysUseDefaultTargetUrl", Boolean.TRUE);
} }
successHandler.addPropertyValue("requestCache", requestCache);
successHandler.addPropertyValue("defaultTargetUrl", StringUtils.hasText(defaultTargetUrl) ? defaultTargetUrl : DEF_FORM_LOGIN_TARGET_URL); successHandler.addPropertyValue("defaultTargetUrl", StringUtils.hasText(defaultTargetUrl) ? defaultTargetUrl : DEF_FORM_LOGIN_TARGET_URL);
filterBuilder.addPropertyValue("authenticationSuccessHandler", successHandler.getBeanDefinition()); filterBuilder.addPropertyValue("authenticationSuccessHandler", successHandler.getBeanDefinition());
} }

View File

@ -63,6 +63,8 @@ import org.springframework.security.web.authentication.www.BasicProcessingFilter
import org.springframework.security.web.authentication.www.BasicProcessingFilterEntryPoint; import org.springframework.security.web.authentication.www.BasicProcessingFilterEntryPoint;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextPersistenceFilter; import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
import org.springframework.security.web.session.SessionFixationProtectionFilter; import org.springframework.security.web.session.SessionFixationProtectionFilter;
import org.springframework.security.web.util.AntUrlPathMatcher; import org.springframework.security.web.util.AntUrlPathMatcher;
import org.springframework.security.web.util.RegexUrlPathMatcher; import org.springframework.security.web.util.RegexUrlPathMatcher;
@ -204,8 +206,11 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
DomUtils.getChildElementByTagName(element, Elements.PORT_MAPPINGS), pc); DomUtils.getChildElementByTagName(element, Elements.PORT_MAPPINGS), pc);
RootBeanDefinition rememberMeFilter = createRememberMeFilter(element, pc, authenticationManager); RootBeanDefinition rememberMeFilter = createRememberMeFilter(element, pc, authenticationManager);
BeanDefinition anonFilter = createAnonymousFilter(element, pc); BeanDefinition anonFilter = createAnonymousFilter(element, pc);
BeanReference requestCache = createRequestCache(element, pc, allowSessionCreation);
BeanDefinition requestCacheAwareFilter = new RootBeanDefinition(RequestCacheAwareFilter.class);
requestCacheAwareFilter.getPropertyValues().addPropertyValue("requestCache", requestCache);
BeanDefinition etf = createExceptionTranslationFilter(element, pc, allowSessionCreation); BeanDefinition etf = createExceptionTranslationFilter(element, pc, requestCache);
RootBeanDefinition sfpf = createSessionFixationProtectionFilter(pc, element.getAttribute(ATT_SESSION_FIXATION_PROTECTION), RootBeanDefinition sfpf = createSessionFixationProtectionFilter(pc, element.getAttribute(ATT_SESSION_FIXATION_PROTECTION),
sessionRegistryRef); sessionRegistryRef);
BeanDefinition fsi = createFilterSecurityInterceptor(element, pc, matcher, convertPathsToLowerCase, authenticationManager); BeanDefinition fsi = createFilterSecurityInterceptor(element, pc, matcher, convertPathsToLowerCase, authenticationManager);
@ -223,9 +228,9 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
final FilterAndEntryPoint basic = createBasicFilter(element, pc, autoConfig, authenticationManager); final FilterAndEntryPoint basic = createBasicFilter(element, pc, autoConfig, authenticationManager);
final FilterAndEntryPoint form = createFormLoginFilter(element, pc, autoConfig, allowSessionCreation, final FilterAndEntryPoint form = createFormLoginFilter(element, pc, autoConfig, allowSessionCreation,
sfpf, authenticationManager); sfpf, authenticationManager, requestCache);
final FilterAndEntryPoint openID = createOpenIDLoginFilter(element, pc, autoConfig, allowSessionCreation, final FilterAndEntryPoint openID = createOpenIDLoginFilter(element, pc, autoConfig, allowSessionCreation,
sfpf, authenticationManager); sfpf, authenticationManager, requestCache);
String rememberMeServicesId = null; String rememberMeServicesId = null;
if (rememberMeFilter != null) { if (rememberMeFilter != null) {
@ -298,6 +303,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
unorderedFilterChain.add(new OrderDecorator(basic.filter, BASIC_PROCESSING_FILTER)); unorderedFilterChain.add(new OrderDecorator(basic.filter, BASIC_PROCESSING_FILTER));
} }
unorderedFilterChain.add(new OrderDecorator(requestCacheAwareFilter, REQUEST_CACHE_FILTER));
if (servApiFilter != null) { if (servApiFilter != null) {
unorderedFilterChain.add(new OrderDecorator(servApiFilter, SERVLET_API_SUPPORT_FILTER)); unorderedFilterChain.add(new OrderDecorator(servApiFilter, SERVLET_API_SUPPORT_FILTER));
} }
@ -751,10 +758,21 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
return new RuntimeBeanReference(id); return new RuntimeBeanReference(id);
} }
private BeanDefinition createExceptionTranslationFilter(Element element, ParserContext pc, boolean allowSessionCreation) { private BeanReference createRequestCache(Element element, ParserContext pc, boolean allowSessionCreation) {
BeanDefinitionBuilder requestCache = BeanDefinitionBuilder.rootBeanDefinition(HttpSessionRequestCache.class);
requestCache.addPropertyValue("createSessionAllowed", Boolean.valueOf(allowSessionCreation));
BeanDefinition bean = requestCache.getBeanDefinition();
String id = pc.getReaderContext().registerWithGeneratedName(bean);
pc.registerBeanComponent(new BeanComponentDefinition(bean, id));
return new RuntimeBeanReference(id);
}
private BeanDefinition createExceptionTranslationFilter(Element element, ParserContext pc, BeanReference requestCache) {
BeanDefinitionBuilder exceptionTranslationFilterBuilder BeanDefinitionBuilder exceptionTranslationFilterBuilder
= BeanDefinitionBuilder.rootBeanDefinition(ExceptionTranslationFilter.class); = BeanDefinitionBuilder.rootBeanDefinition(ExceptionTranslationFilter.class);
exceptionTranslationFilterBuilder.addPropertyValue("createSessionAllowed", Boolean.valueOf(allowSessionCreation));
exceptionTranslationFilterBuilder.addPropertyValue("accessDeniedHandler", createAccessDeniedHandler(element, pc)); exceptionTranslationFilterBuilder.addPropertyValue("accessDeniedHandler", createAccessDeniedHandler(element, pc));
@ -911,7 +929,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
} }
private FilterAndEntryPoint createFormLoginFilter(Element element, ParserContext pc, boolean autoConfig, private FilterAndEntryPoint createFormLoginFilter(Element element, ParserContext pc, boolean autoConfig,
boolean allowSessionCreation, RootBeanDefinition sfpf, BeanReference authManager) { boolean allowSessionCreation, RootBeanDefinition sfpf, BeanReference authManager, BeanReference requestCache) {
RootBeanDefinition formLoginFilter = null; RootBeanDefinition formLoginFilter = null;
RootBeanDefinition formLoginEntryPoint = null; RootBeanDefinition formLoginEntryPoint = null;
@ -919,7 +937,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
if (formLoginElt != null || autoConfig) { if (formLoginElt != null || autoConfig) {
FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser("/j_spring_security_check", FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser("/j_spring_security_check",
AUTHENTICATION_PROCESSING_FILTER_CLASS); AUTHENTICATION_PROCESSING_FILTER_CLASS, requestCache);
parser.parse(formLoginElt, pc, sfpf); parser.parse(formLoginElt, pc, sfpf);
formLoginFilter = parser.getFilterBean(); formLoginFilter = parser.getFilterBean();
@ -935,14 +953,14 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
} }
private FilterAndEntryPoint createOpenIDLoginFilter(Element element, ParserContext pc, boolean autoConfig, private FilterAndEntryPoint createOpenIDLoginFilter(Element element, ParserContext pc, boolean autoConfig,
boolean allowSessionCreation, RootBeanDefinition sfpf, BeanReference authManager) { boolean allowSessionCreation, RootBeanDefinition sfpf, BeanReference authManager, BeanReference requestCache) {
Element openIDLoginElt = DomUtils.getChildElementByTagName(element, Elements.OPENID_LOGIN); Element openIDLoginElt = DomUtils.getChildElementByTagName(element, Elements.OPENID_LOGIN);
RootBeanDefinition openIDFilter = null; RootBeanDefinition openIDFilter = null;
RootBeanDefinition openIDEntryPoint = null; RootBeanDefinition openIDEntryPoint = null;
if (openIDLoginElt != null) { if (openIDLoginElt != null) {
FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser("/j_spring_openid_security_check", FormLoginBeanDefinitionParser parser = new FormLoginBeanDefinitionParser("/j_spring_openid_security_check",
OPEN_ID_AUTHENTICATION_PROCESSING_FILTER_CLASS); OPEN_ID_AUTHENTICATION_PROCESSING_FILTER_CLASS, requestCache);
parser.parse(openIDLoginElt, pc, sfpf); parser.parse(openIDLoginElt, pc, sfpf);
openIDFilter = parser.getFilterBean(); openIDFilter = parser.getFilterBean();

View File

@ -69,6 +69,7 @@ import org.springframework.security.web.authentication.ui.DefaultLoginPageGenera
import org.springframework.security.web.authentication.www.BasicProcessingFilter; import org.springframework.security.web.authentication.www.BasicProcessingFilter;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextPersistenceFilter; import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
import org.springframework.security.web.session.SessionFixationProtectionFilter; import org.springframework.security.web.session.SessionFixationProtectionFilter;
import org.springframework.security.web.wrapper.SecurityContextHolderAwareRequestFilter; import org.springframework.security.web.wrapper.SecurityContextHolderAwareRequestFilter;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
@ -78,7 +79,7 @@ import org.springframework.util.ReflectionUtils;
* @version $Id$ * @version $Id$
*/ */
public class HttpSecurityBeanDefinitionParserTests { public class HttpSecurityBeanDefinitionParserTests {
private static final int AUTO_CONFIG_FILTERS = 10; private static final int AUTO_CONFIG_FILTERS = 11;
private AbstractXmlApplicationContext appContext; private AbstractXmlApplicationContext appContext;
@After @After
@ -132,6 +133,7 @@ public class HttpSecurityBeanDefinitionParserTests {
assertTrue(authProcFilter instanceof UsernamePasswordAuthenticationProcessingFilter); assertTrue(authProcFilter instanceof UsernamePasswordAuthenticationProcessingFilter);
assertTrue(filters.next() instanceof DefaultLoginPageGeneratingFilter); assertTrue(filters.next() instanceof DefaultLoginPageGeneratingFilter);
assertTrue(filters.next() instanceof BasicProcessingFilter); assertTrue(filters.next() instanceof BasicProcessingFilter);
assertTrue(filters.next() instanceof RequestCacheAwareFilter);
assertTrue(filters.next() instanceof SecurityContextHolderAwareRequestFilter); assertTrue(filters.next() instanceof SecurityContextHolderAwareRequestFilter);
assertTrue(filters.next() instanceof AnonymousProcessingFilter); assertTrue(filters.next() instanceof AnonymousProcessingFilter);
assertTrue(filters.next() instanceof ExceptionTranslationFilter); assertTrue(filters.next() instanceof ExceptionTranslationFilter);
@ -209,7 +211,7 @@ public class HttpSecurityBeanDefinitionParserTests {
"<http>" + "<http>" +
" <form-login />" + " <form-login />" +
"</http>" + AUTH_PROVIDER_XML); "</http>" + AUTH_PROVIDER_XML);
assertThat(getFilters("/anything").get(4), instanceOf(AnonymousProcessingFilter.class)); assertThat(getFilters("/anything").get(5), instanceOf(AnonymousProcessingFilter.class));
} }
@Test @Test
@ -219,7 +221,7 @@ public class HttpSecurityBeanDefinitionParserTests {
" <form-login />" + " <form-login />" +
" <anonymous enabled='false'/>" + " <anonymous enabled='false'/>" +
"</http>" + AUTH_PROVIDER_XML); "</http>" + AUTH_PROVIDER_XML);
assertThat(getFilters("/anything").get(4), not(instanceOf(AnonymousProcessingFilter.class))); assertThat(getFilters("/anything").get(5), not(instanceOf(AnonymousProcessingFilter.class)));
} }

View File

@ -30,10 +30,9 @@ import org.springframework.security.authentication.InsufficientAuthenticationExc
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.PortResolver;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.SpringSecurityFilter; import org.springframework.security.web.SpringSecurityFilter;
import org.springframework.security.web.savedrequest.SavedRequest; import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.security.web.util.ThrowableAnalyzer;
import org.springframework.security.web.util.ThrowableCauseExtractor; import org.springframework.security.web.util.ThrowableCauseExtractor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -60,8 +59,9 @@ import org.springframework.util.Assert;
* should commence the authentication process if an * should commence the authentication process if an
* <code>AuthenticationException</code> is detected. Note that this may also * <code>AuthenticationException</code> is detected. Note that this may also
* switch the current protocol from http to https for an SSL login.</li> * switch the current protocol from http to https for an SSL login.</li>
* <li><code>portResolver</code> is used to determine the "real" port that a * <li><tt>requestCache</tt> determines the strategy used to save a request during the authentication process in order
* request was received on.</li> * that it may be retrieved and reused once the user has authenticated. The default implementation is
* {@link HttpSessionRequestCache}.</li>
* </ul> * </ul>
* *
* @author Ben Alex * @author Ben Alex
@ -75,18 +75,16 @@ public class ExceptionTranslationFilter extends SpringSecurityFilter implements
private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
private AuthenticationEntryPoint authenticationEntryPoint; private AuthenticationEntryPoint authenticationEntryPoint;
private AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); private AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
private PortResolver portResolver = new PortResolverImpl(); // private PortResolver portResolver = new PortResolverImpl();
private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
private boolean createSessionAllowed = true;
private boolean justUseSavedRequestOnGet; private RequestCache requestCache = new HttpSessionRequestCache();
//~ Methods ======================================================================================================== //~ Methods ========================================================================================================
public void afterPropertiesSet() throws Exception { public void afterPropertiesSet() throws Exception {
Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint must be specified"); Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint must be specified");
Assert.notNull(portResolver, "portResolver must be specified"); // Assert.notNull(portResolver, "portResolver must be specified");
Assert.notNull(authenticationTrustResolver, "authenticationTrustResolver must be specified");
Assert.notNull(throwableAnalyzer, "throwableAnalyzer must be specified");
} }
public void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, public void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException,
@ -133,14 +131,10 @@ public class ExceptionTranslationFilter extends SpringSecurityFilter implements
return authenticationEntryPoint; return authenticationEntryPoint;
} }
public AuthenticationTrustResolver getAuthenticationTrustResolver() { protected AuthenticationTrustResolver getAuthenticationTrustResolver() {
return authenticationTrustResolver; return authenticationTrustResolver;
} }
public PortResolver getPortResolver() {
return portResolver;
}
private void handleException(HttpServletRequest request, HttpServletResponse response, FilterChain chain, private void handleException(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
RuntimeException exception) throws IOException, ServletException { RuntimeException exception) throws IOException, ServletException {
if (exception instanceof AuthenticationException) { if (exception instanceof AuthenticationException) {
@ -171,48 +165,16 @@ public class ExceptionTranslationFilter extends SpringSecurityFilter implements
} }
} }
/**
* If <code>true</code>, indicates that <code>ExceptionTranslationFilter</code> is permitted to store the target
* URL and exception information in a new <code>HttpSession</code> (the default).
* In situations where you do not wish to unnecessarily create <code>HttpSession</code>s - because the user agent
* will know the failed URL, such as with BASIC or Digest authentication - you may wish to set this property to
* <code>false</code>.
* <p>
* Remember to also set
* {@link org.springframework.security.web.context.HttpSessionSecurityContextRepository#setAllowSessionCreation(boolean)}
* to <code>false</code> if you set this property to <code>false</code>.
*
* @return <code>true</code> if the <code>HttpSession</code> will be
* used to store information about the failed request, <code>false</code>
* if the <code>HttpSession</code> will not be used
*/
public boolean isCreateSessionAllowed() {
return createSessionAllowed;
}
protected void sendStartAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, protected void sendStartAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
AuthenticationException reason) throws ServletException, IOException { AuthenticationException reason) throws ServletException, IOException {
// SEC-112: Clear the SecurityContextHolder's Authentication, as the // SEC-112: Clear the SecurityContextHolder's Authentication, as the
// existing Authentication is no longer considered valid // existing Authentication is no longer considered valid
SecurityContextHolder.getContext().setAuthentication(null); SecurityContextHolder.getContext().setAuthentication(null);
saveRequestIfAllowed(request); requestCache.saveRequest(request, response);
logger.debug("Calling Authentication entry point."); logger.debug("Calling Authentication entry point.");
authenticationEntryPoint.commence(request, response, reason); authenticationEntryPoint.commence(request, response, reason);
} }
private void saveRequestIfAllowed(HttpServletRequest request) {
if (!justUseSavedRequestOnGet || "GET".equals(request.getMethod())) {
SavedRequest savedRequest = new SavedRequest(request, portResolver);
if (createSessionAllowed || request.getSession(false) != null) {
// Store the HTTP request itself. Used by AbstractAuthenticationProcessingFilter
// for redirection after successful authentication (SEC-29)
request.getSession().setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest);
logger.debug("SavedRequest added to Session: " + savedRequest);
}
}
}
public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) { public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "AccessDeniedHandler required"); Assert.notNull(accessDeniedHandler, "AccessDeniedHandler required");
this.accessDeniedHandler = accessDeniedHandler; this.accessDeniedHandler = accessDeniedHandler;
@ -223,27 +185,30 @@ public class ExceptionTranslationFilter extends SpringSecurityFilter implements
} }
public void setAuthenticationTrustResolver(AuthenticationTrustResolver authenticationTrustResolver) { public void setAuthenticationTrustResolver(AuthenticationTrustResolver authenticationTrustResolver) {
Assert.notNull(authenticationTrustResolver, "authenticationTrustResolver must not be null");
this.authenticationTrustResolver = authenticationTrustResolver; this.authenticationTrustResolver = authenticationTrustResolver;
} }
public void setCreateSessionAllowed(boolean createSessionAllowed) { // public void setCreateSessionAllowed(boolean createSessionAllowed) {
this.createSessionAllowed = createSessionAllowed; // this.createSessionAllowed = createSessionAllowed;
} // }
public void setPortResolver(PortResolver portResolver) { // public void setPortResolver(PortResolver portResolver) {
this.portResolver = portResolver; // this.portResolver = portResolver;
} // }
public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) { public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
Assert.notNull(throwableAnalyzer, "throwableAnalyzer must not be null");
this.throwableAnalyzer = throwableAnalyzer; this.throwableAnalyzer = throwableAnalyzer;
} }
/** /**
* If <code>true</code>, will only use <code>SavedRequest</code> to determine the target URL on successful * The RequestCache implementation used to store the current request before starting authentication.
* authentication if the request that caused the authentication request was a GET. Defaults to false. * Defaults to an {@link HttpSessionRequestCache}.
*/ */
public void setJustUseSavedRequestOnGet(boolean justUseSavedRequestOnGet) { public void setRequestCache(RequestCache requestCache) {
this.justUseSavedRequestOnGet = justUseSavedRequestOnGet; Assert.notNull(requestCache, "requestCache cannot be null");
this.requestCache = requestCache;
} }
/** /**

View File

@ -5,13 +5,15 @@ import java.io.IOException;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.web.access.ExceptionTranslationFilter; import org.springframework.security.web.access.ExceptionTranslationFilter;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.savedrequest.SavedRequest; import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.security.web.util.RedirectUtils; import org.springframework.security.web.util.RedirectUtils;
import org.springframework.security.web.wrapper.SavedRequestAwareWrapper;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
/** /**
@ -33,13 +35,13 @@ import org.springframework.util.StringUtils;
* Any <tt>SavedRequest</tt> will again be removed. * Any <tt>SavedRequest</tt> will again be removed.
* </li> * </li>
* <li> * <li>
* If a {@link SavedRequest} is found in the session (as set by the {@link ExceptionTranslationFilter} to record * If a {@link SavedRequest} is found in the <tt>RequestCache</tt> (as set by the {@link ExceptionTranslationFilter} to
* the original destination before the authentication process commenced), a redirect will be performed to the * record the original destination before the authentication process commenced), a redirect will be performed to the
* Url of that original destination. The <tt>SavedRequest</tt> object will remain in the session and be picked up * Url of that original destination. The <tt>SavedRequest</tt> object will remain cached and be picked up
* when the redirected request is received (See {@link SavedRequestAwareWrapper}). * when the redirected request is received (See {@link SavedRequestAwareWrapper}).
* </li> * </li>
* <li> * <li>
* If no <tt>SavedRequest</tt> is found in the session, it will delegate to the base class. * If no <tt>SavedRequest</tt> is found, it will delegate to the base class.
* </li> * </li>
* </ul> * </ul>
* *
@ -49,11 +51,14 @@ import org.springframework.util.StringUtils;
* @since 3.0 * @since 3.0
*/ */
public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuthenticationSuccessHandler { public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuthenticationSuccessHandler {
protected final Log logger = LogFactory.getLog(this.getClass());
private RequestCache requestCache = new HttpSessionRequestCache();
@Override @Override
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws ServletException, IOException { Authentication authentication) throws ServletException, IOException {
SavedRequest savedRequest = getSavedRequest(request); SavedRequest savedRequest = requestCache.getRequest(request, response);
if (savedRequest == null) { if (savedRequest == null) {
super.onAuthenticationSuccess(request, response, authentication); super.onAuthenticationSuccess(request, response, authentication);
@ -62,7 +67,7 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth
} }
if (isAlwaysUseDefaultTargetUrl() || StringUtils.hasText(request.getParameter(getTargetUrlParameter()))) { if (isAlwaysUseDefaultTargetUrl() || StringUtils.hasText(request.getParameter(getTargetUrlParameter()))) {
removeSavedRequest(request); requestCache.removeRequest(request, response);
super.onAuthenticationSuccess(request, response, authentication); super.onAuthenticationSuccess(request, response, authentication);
return; return;
@ -74,22 +79,22 @@ public class SavedRequestAwareAuthenticationSuccessHandler extends SimpleUrlAuth
RedirectUtils.sendRedirect(request, response, targetUrl, isUseRelativeContext()); RedirectUtils.sendRedirect(request, response, targetUrl, isUseRelativeContext());
} }
private SavedRequest getSavedRequest(HttpServletRequest request) { // private SavedRequest getSavedRequest(HttpServletRequest request) {
HttpSession session = request.getSession(false); // HttpSession session = request.getSession(false);
//
if (session != null) { // if (session != null) {
return (SavedRequest) session.getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); // return (SavedRequest) session.getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY);
} // }
//
return null; // return null;
} // }
//
private void removeSavedRequest(HttpServletRequest request) { // private void removeSavedRequest(HttpServletRequest request) {
HttpSession session = request.getSession(false); // HttpSession session = request.getSession(false);
//
if (session != null) { // if (session != null) {
logger.debug("Removing SavedRequest from session if present"); // logger.debug("Removing SavedRequest from session if present");
session.removeAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY); // session.removeAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY);
} // }
} // }
} }

View File

@ -0,0 +1,99 @@
package org.springframework.security.web.savedrequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.web.PortResolver;
import org.springframework.security.web.PortResolverImpl;
/**
* <tt>RequestCache</tt> which stores the SavedRequest in the HttpSession.
*
* @author Luke Taylor
* @version $Id$
* @since 3.0
*/
public class HttpSessionRequestCache implements RequestCache {
protected final Log logger = LogFactory.getLog(this.getClass());
private PortResolver portResolver = new PortResolverImpl();
private boolean createSessionAllowed = true;
private boolean justUseSavedRequestOnGet;
/**
* Stores the current request, provided the configuration properties allow it.
*/
public void saveRequest(HttpServletRequest request, HttpServletResponse response) {
if (!justUseSavedRequestOnGet || "GET".equals(request.getMethod())) {
SavedRequest savedRequest = new SavedRequest(request, portResolver);
if (createSessionAllowed || request.getSession(false) != null) {
// Store the HTTP request itself. Used by AbstractAuthenticationProcessingFilter
// for redirection after successful authentication (SEC-29)
request.getSession().setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest);
logger.debug("SavedRequest added to Session: " + savedRequest);
}
}
}
public SavedRequest getRequest(HttpServletRequest currentRequest, HttpServletResponse response) {
HttpSession session = currentRequest.getSession(false);
if (session != null) {
return (SavedRequest) session.getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY);
}
return null;
}
public void removeRequest(HttpServletRequest currentRequest, HttpServletResponse response) {
HttpSession session = currentRequest.getSession(false);
if (session != null) {
logger.debug("Removing SavedRequest from session if present");
session.removeAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY);
}
}
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
SavedRequest saved = getRequest(request, response);
if (saved == null) {
return null;
}
if (!saved.doesRequestMatch(request, portResolver)) {
logger.debug("saved request doesn't match");
return null;
}
return new SavedRequestAwareWrapper(saved, request);
}
/**
* If <code>true</code>, will only use <code>SavedRequest</code> to determine the target URL on successful
* authentication if the request that caused the authentication request was a GET. Defaults to false.
*/
public void setJustUseSavedRequestOnGet(boolean justUseSavedRequestOnGet) {
this.justUseSavedRequestOnGet = justUseSavedRequestOnGet;
}
/**
* If <code>true</code>, indicates that it is permitted to store the target
* URL and exception information in a new <code>HttpSession</code> (the default).
* In situations where you do not wish to unnecessarily create <code>HttpSession</code>s - because the user agent
* will know the failed URL, such as with BASIC or Digest authentication - you may wish to set this property to
* <code>false</code>.
*/
public void setCreateSessionAllowed(boolean createSessionAllowed) {
this.createSessionAllowed = createSessionAllowed;
}
public void setPortResolver(PortResolver portResolver) {
this.portResolver = portResolver;
}
}

View File

@ -0,0 +1,31 @@
package org.springframework.security.web.savedrequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Null implementation of <tt>RequestCache</tt>.
* Typically used when creation of a session is not desired.
*
* @author Luke Taylor
* @version $Id$
* @since 3.0
*/
public class NullRequestCache implements RequestCache {
public SavedRequest getRequest(HttpServletRequest request, HttpServletResponse response) {
return null;
}
public void removeRequest(HttpServletRequest request, HttpServletResponse response) {
}
public void saveRequest(HttpServletRequest request, HttpServletResponse response) {
}
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
return null;
}
}

View File

@ -0,0 +1,47 @@
package org.springframework.security.web.savedrequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Implements "saved request" logic, allowing a single request to be retrieved and restarted after redirecting to
* an authentication mechanism.
*
* @author Luke Taylor
* @version $Id$
* @since 3.0
*/
public interface RequestCache {
/**
* Caches the current request for later retrieval, once authentication has taken place.
* Used by <tt>ExceptionTranslationFilter</tt>.
*
* @param request the request to be stored
*/
void saveRequest(HttpServletRequest request, HttpServletResponse response);
/**
* Returns the saved request, leaving it cached.
* @param currentRequest the current
* @return the saved request which was previously cached, or null if there is none.
*/
SavedRequest getRequest(HttpServletRequest request, HttpServletResponse response);
/**
* Returns a wrapper around the saved request, if it matches the current request. The saved request should
* be removed from the cache.
*
* @param request
* @param response
* @return the wrapped save request, if it matches the
*/
HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response);
/**
* Removes and returns the cached request
* @param currentRequest
*/
void removeRequest(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -0,0 +1,41 @@
package org.springframework.security.web.savedrequest;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.web.SpringSecurityFilter;
/**
* Responsible for reconstituting the saved request if one is cached and it matches the current request.
* <p>
* It will call {@link RequestCache#getMatchingRequest(HttpServletRequest, HttpServletResponse) getMatchingRequest}
* on the configured <tt>RequestCache</tt>. If the method returns a value (a wrapper of the saved request), it will
* pass this to the filter chain's <tt>doFilter</tt> method.
* If null is returned by the cache, the original request is used and the filter has no effect.
*
* @author Luke Taylor
* @version $Id$
* @since 3.0
*/
public class RequestCacheAwareFilter extends SpringSecurityFilter {
private RequestCache requestCache = new HttpSessionRequestCache();
@Override
protected void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest wrappedSavedRequest = requestCache.getMatchingRequest(request, response);
chain.doFilter(wrappedSavedRequest == null ? request : wrappedSavedRequest, response);
}
public void setRequestCache(RequestCache requestCache) {
this.requestCache = requestCache;
}
}

View File

@ -35,7 +35,7 @@ import java.util.TreeMap;
/** /**
* Represents central information from a <code>HttpServletRequest</code>.<p>This class is used by {@link * Represents central information from a <code>HttpServletRequest</code>.<p>This class is used by {@link
* org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter} and {@link org.springframework.security.web.wrapper.SavedRequestAwareWrapper} to * org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter} and {@link org.springframework.security.web.savedrequest.SavedRequestAwareWrapper} to
* reproduce the request after successful authentication. An instance of this class is stored at the time of an * reproduce the request after successful authentication. An instance of this class is stored at the time of an
* authentication exception by {@link org.springframework.security.web.access.ExceptionTranslationFilter}.</p> * authentication exception by {@link org.springframework.security.web.access.ExceptionTranslationFilter}.</p>
* <p><em>IMPLEMENTATION NOTE</em>: It is assumed that this object is accessed only from the context of a single * <p><em>IMPLEMENTATION NOTE</em>: It is assumed that this object is accessed only from the context of a single

View File

@ -13,7 +13,7 @@
* limitations under the License. * limitations under the License.
*/ */
package org.springframework.security.web.wrapper; package org.springframework.security.web.savedrequest;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
@ -30,14 +30,10 @@ import java.util.TimeZone;
import javax.servlet.http.Cookie; import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpServletRequestWrapper;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.security.web.PortResolver;
import org.springframework.security.web.savedrequest.Enumerator;
import org.springframework.security.web.savedrequest.FastHttpDateFormat;
import org.springframework.security.web.savedrequest.SavedRequest;
/** /**
@ -47,16 +43,18 @@ import org.springframework.security.web.savedrequest.SavedRequest;
* Nevertheless, the important data from the original request is emulated and this should prove * Nevertheless, the important data from the original request is emulated and this should prove
* adequate for most purposes (in particular standard HTTP GET and POST operations).</p> * adequate for most purposes (in particular standard HTTP GET and POST operations).</p>
* *
* <p>Added into a request by {@link org.springframework.security.web.wrapper.SecurityContextHolderAwareRequestFilter}.</p> * <p>
* Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}.
* </p>
* *
* * TODO: savedRequest cannot now be null, so convert the tests to reflect this and remove the null checks.
* @see SecurityContextHolderAwareRequestFilter
* *
* @author Andrey Grebnev * @author Andrey Grebnev
* @author Ben Alex * @author Ben Alex
* @author Luke Taylor
* @version $Id$ * @version $Id$
*/ */
public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestWrapper { class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
//~ Static fields/initializers ===================================================================================== //~ Static fields/initializers =====================================================================================
protected static final Log logger = LogFactory.getLog(SavedRequestAwareWrapper.class); protected static final Log logger = LogFactory.getLog(SavedRequestAwareWrapper.class);
@ -77,28 +75,9 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
//~ Constructors =================================================================================================== //~ Constructors ===================================================================================================
public SavedRequestAwareWrapper(HttpServletRequest request, PortResolver portResolver, String rolePrefix) { public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) {
super(request, portResolver, rolePrefix); super(request);
HttpSession session = request.getSession(false);
if (session == null) {
if (logger.isDebugEnabled()) {
logger.debug("Wrapper not replaced; no session available for SavedRequest extraction");
}
return;
}
SavedRequest saved = (SavedRequest) session.getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY);
if ((saved != null) && saved.doesRequestMatch(request, portResolver)) {
if (logger.isDebugEnabled()) {
logger.debug("Wrapper replaced; SavedRequest was: " + saved);
}
savedRequest = saved; savedRequest = saved;
session.removeAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY);
formats[0] = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US); formats[0] = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
formats[1] = new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US); formats[1] = new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz", Locale.US);
@ -107,11 +86,6 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
formats[0].setTimeZone(GMT_ZONE); formats[0].setTimeZone(GMT_ZONE);
formats[1].setTimeZone(GMT_ZONE); formats[1].setTimeZone(GMT_ZONE);
formats[2].setTimeZone(GMT_ZONE); formats[2].setTimeZone(GMT_ZONE);
} else {
if (logger.isDebugEnabled()) {
logger.debug("Wrapper not replaced; SavedRequest was: " + saved);
}
}
} }
//~ Methods ======================================================================================================== //~ Methods ========================================================================================================
@ -120,18 +94,17 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
public Cookie[] getCookies() { public Cookie[] getCookies() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getCookies(); return super.getCookies();
} else { }
List<Cookie> cookies = savedRequest.getCookies(); List<Cookie> cookies = savedRequest.getCookies();
return cookies.toArray(new Cookie[cookies.size()]); return cookies.toArray(new Cookie[cookies.size()]);
} }
}
@Override @Override
public long getDateHeader(String name) { public long getDateHeader(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getDateHeader(name); return super.getDateHeader(name);
} else { }
String value = getHeader(name); String value = getHeader(name);
if (value == null) { if (value == null) {
@ -147,13 +120,13 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
throw new IllegalArgumentException(value); throw new IllegalArgumentException(value);
} }
}
@Override @Override
public String getHeader(String name) { public String getHeader(String name) {
if (savedRequest == null) { if (savedRequest == null) {
return super.getHeader(name); return super.getHeader(name);
} else { }
String header = null; String header = null;
Iterator<String> iterator = savedRequest.getHeaderValues(name); Iterator<String> iterator = savedRequest.getHeaderValues(name);
@ -165,16 +138,15 @@ public class SavedRequestAwareWrapper extends SecurityContextHolderAwareRequestW
return header; return header;
} }
}
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public Enumeration getHeaderNames() { public Enumeration getHeaderNames() {
if (savedRequest == null) { if (savedRequest == null) {
return super.getHeaderNames(); return super.getHeaderNames();
} else {
return new Enumerator<String>(savedRequest.getHeaderNames());
} }
return new Enumerator<String>(savedRequest.getHeaderNames());
} }
@Override @Override

View File

@ -16,75 +16,41 @@
package org.springframework.security.web.wrapper; package org.springframework.security.web.wrapper;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Constructor;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.security.web.PortResolver;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.SpringSecurityFilter; import org.springframework.security.web.SpringSecurityFilter;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
/** /**
* A <code>Filter</code> which populates the <code>ServletRequest</code> with a new request wrapper. * A <code>Filter</code> which populates the <code>ServletRequest</code> with a request wrapper
* Several request wrappers are included with the framework. The simplest version is {@link * which implements the servlet API security methods.
* SecurityContextHolderAwareRequestWrapper}. A more complex and powerful request wrapper is
* {@link SavedRequestAwareWrapper}. The latter is also the default.
* <p> * <p>
* To modify the wrapper used, call {@link #setWrapperClass(Class)}. * The wrapper class used is {@link SecurityContextHolderAwareRequestWrapper}.
* <p>
* Any request wrapper configured for instantiation by this class must provide a public constructor that
* accepts two arguments, being a <code>HttpServletRequest</code> and a <code>PortResolver</code>.
* *
* @author Orlando Garcia Carmona * @author Orlando Garcia Carmona
* @author Ben Alex * @author Ben Alex
* @author Luke Taylor
* @version $Id$ * @version $Id$
*/ */
public class SecurityContextHolderAwareRequestFilter extends SpringSecurityFilter { public class SecurityContextHolderAwareRequestFilter extends SpringSecurityFilter {
//~ Instance fields ================================================================================================ //~ Instance fields ================================================================================================
private Class<? extends HttpServletRequest> wrapperClass = SavedRequestAwareWrapper.class;
private Constructor<? extends HttpServletRequest> constructor;
private PortResolver portResolver = new PortResolverImpl();
private String rolePrefix; private String rolePrefix;
//~ Methods ======================================================================================================== //~ Methods ========================================================================================================
public void setPortResolver(PortResolver portResolver) {
Assert.notNull(portResolver, "PortResolver required");
this.portResolver = portResolver;
}
@SuppressWarnings("unchecked")
public void setWrapperClass(Class wrapperClass) {
Assert.notNull(wrapperClass, "WrapperClass required");
Assert.isTrue(HttpServletRequest.class.isAssignableFrom(wrapperClass), "Wrapper must be a HttpServletRequest");
this.wrapperClass = wrapperClass;
}
public void setRolePrefix(String rolePrefix) { public void setRolePrefix(String rolePrefix) {
Assert.notNull(rolePrefix, "Role prefix must not be null"); Assert.notNull(rolePrefix, "Role prefix must not be null");
this.rolePrefix = rolePrefix.trim(); this.rolePrefix = rolePrefix.trim();
} }
protected void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { protected void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
if (!wrapperClass.isAssignableFrom(request.getClass())) { throws IOException, ServletException {
try { chain.doFilter(new SecurityContextHolderAwareRequestWrapper(request, rolePrefix), response);
if (constructor == null) {
constructor = wrapperClass.getConstructor(HttpServletRequest.class, PortResolver.class, String.class);
}
request = constructor.newInstance(request, portResolver, rolePrefix);
} catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex);
}
}
chain.doFilter(request, response);
} }
} }

View File

@ -16,6 +16,12 @@
package org.springframework.security.web.wrapper; package org.springframework.security.web.wrapper;
import java.security.Principal;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@ -23,25 +29,18 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.web.PortResolver;
import java.security.Principal;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
/** /**
* A Spring Security-aware <code>HttpServletRequestWrapper</code>, which uses the * A Spring Security-aware <code>HttpServletRequestWrapper</code>, which uses the
* <code>SecurityContext</code>-defined <code>Authentication</code> object for {@link * <code>SecurityContext</code>-defined <code>Authentication</code> object to implement the servlet API security
* SecurityContextHolderAwareRequestWrapper#isUserInRole(java.lang.String)} and {@link * methods {@link SecurityContextHolderAwareRequestWrapper#isUserInRole(String)} and {@link
* javax.servlet.http.HttpServletRequestWrapper#getRemoteUser()} responses. * HttpServletRequestWrapper#getRemoteUser()}.
* *
* @see SecurityContextHolderAwareRequestFilter * @see SecurityContextHolderAwareRequestFilter
* *
* @author Orlando Garcia Carmona * @author Orlando Garcia Carmona
* @author Ben Alex * @author Ben Alex
* @author Luke Taylor
* @version $Id$ * @version $Id$
*/ */
public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequestWrapper { public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequestWrapper {
@ -57,10 +56,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest
//~ Constructors =================================================================================================== //~ Constructors ===================================================================================================
public SecurityContextHolderAwareRequestWrapper( public SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, String rolePrefix) {
HttpServletRequest request,
PortResolver portResolver,
String rolePrefix) {
super(request); super(request);
this.rolePrefix = rolePrefix; this.rolePrefix = rolePrefix;

View File

@ -15,29 +15,35 @@
package org.springframework.security.web.access; package org.springframework.security.web.access;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;
import java.io.IOException; import java.io.IOException;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import junit.framework.TestCase; import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.MockPortResolver; import org.springframework.security.MockPortResolver;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.SavedRequest; import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.security.web.util.ThrowableAnalyzer;
/** /**
* Tests {@link ExceptionTranslationFilter}. * Tests {@link ExceptionTranslationFilter}.
@ -45,11 +51,11 @@ import org.springframework.security.web.savedrequest.SavedRequest;
* @author Ben Alex * @author Ben Alex
* @version $Id$ * @version $Id$
*/ */
public class ExceptionTranslationFilterTests extends TestCase { public class ExceptionTranslationFilterTests {
//~ Methods ========================================================================================================
protected void tearDown() throws Exception { @After
super.tearDown(); @Before
public void clearContext() throws Exception {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
} }
@ -65,6 +71,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
return savedRequest.getFullRequestUrl(); return savedRequest.getFullRequestUrl();
} }
@Test
public void testAccessDeniedWhenAnonymous() throws Exception { public void testAccessDeniedWhenAnonymous() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
@ -76,7 +83,8 @@ public class ExceptionTranslationFilterTests extends TestCase {
request.setRequestURI("/mycontext/secure/page.html"); request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an access denied exception // Setup the FilterChain to thrown an access denied exception
MockFilterChain chain = new MockFilterChain(true, false, false, false); FilterChain fc = mock(FilterChain.class);
doThrow(new AccessDeniedException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
// Setup SecurityContextHolder, as filter needs to check if user is // Setup SecurityContextHolder, as filter needs to check if user is
// anonymous // anonymous
@ -86,24 +94,27 @@ public class ExceptionTranslationFilterTests extends TestCase {
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
filter.setAuthenticationTrustResolver(new AuthenticationTrustResolverImpl());
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, chain); filter.doFilter(request, response, fc);
assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
assertEquals("http://www.example.com/mycontext/secure/page.html", getSavedRequestUrl(request)); assertEquals("http://www.example.com/mycontext/secure/page.html", getSavedRequestUrl(request));
} }
@Test
public void testAccessDeniedWhenNonAnonymous() throws Exception { public void testAccessDeniedWhenNonAnonymous() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setServletPath("/secure/page.html"); request.setServletPath("/secure/page.html");
// Setup the FilterChain to thrown an access denied exception // Setup the FilterChain to thrown an access denied exception
MockFilterChain chain = new MockFilterChain(true, false, false, false); FilterChain fc = mock(FilterChain.class);
doThrow(new AccessDeniedException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
// Setup SecurityContextHolder, as filter needs to check if user is // Setup SecurityContextHolder, as filter needs to check if user is
// anonymous // anonymous
SecurityContextHolder.getContext().setAuthentication(null); SecurityContextHolder.clearContext();
// Setup a new AccessDeniedHandlerImpl that will do a "forward" // Setup a new AccessDeniedHandlerImpl that will do a "forward"
AccessDeniedHandlerImpl adh = new AccessDeniedHandlerImpl(); AccessDeniedHandlerImpl adh = new AccessDeniedHandlerImpl();
@ -115,22 +126,13 @@ public class ExceptionTranslationFilterTests extends TestCase {
filter.setAccessDeniedHandler(adh); filter.setAccessDeniedHandler(adh);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, chain); filter.doFilter(request, response, fc);
assertEquals(403, response.getStatus()); assertEquals(403, response.getStatus());
assertEquals(AccessDeniedException.class, request.getAttribute( assertEquals(AccessDeniedException.class, request.getAttribute(
AccessDeniedHandlerImpl.SPRING_SECURITY_ACCESS_DENIED_EXCEPTION_KEY).getClass()); AccessDeniedHandlerImpl.SPRING_SECURITY_ACCESS_DENIED_EXCEPTION_KEY).getClass());
} }
public void testGettersSetters() { @Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint());
assertTrue(filter.getAuthenticationEntryPoint() != null);
filter.setPortResolver(new MockPortResolver(80, 443));
assertTrue(filter.getPortResolver() != null);
}
public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthenticationException() throws Exception { public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWhenAuthenticationException() throws Exception {
// Setup our HTTP request // Setup our HTTP request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
@ -142,25 +144,20 @@ public class ExceptionTranslationFilterTests extends TestCase {
request.setRequestURI("/mycontext/secure/page.html"); request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an authentication failure exception // Setup the FilterChain to thrown an authentication failure exception
MockFilterChain chain = new MockFilterChain(false, true, false, false); FilterChain fc = mock(FilterChain.class);
doThrow(new BadCredentialsException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
filter.setPortResolver(new MockPortResolver(80, 443)); filter.afterPropertiesSet();
/*
* Disabled the call to afterPropertiesSet as it requires
* applicationContext to be injected before it is invoked. We do not
* have this filter configured in IOC for this test hence no
* ApplicationContext
*/
// filter.afterPropertiesSet();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, chain); filter.doFilter(request, response, fc);
assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
assertEquals("http://www.example.com/mycontext/secure/page.html", getSavedRequestUrl(request)); assertEquals("http://www.example.com/mycontext/secure/page.html", getSavedRequestUrl(request));
} }
@Test
public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException() public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException()
throws Exception { throws Exception {
// Setup our HTTP request // Setup our HTTP request
@ -173,61 +170,52 @@ public class ExceptionTranslationFilterTests extends TestCase {
request.setRequestURI("/mycontext/secure/page.html"); request.setRequestURI("/mycontext/secure/page.html");
// Setup the FilterChain to thrown an authentication failure exception // Setup the FilterChain to thrown an authentication failure exception
MockFilterChain chain = new MockFilterChain(false, true, false, false); FilterChain fc = mock(FilterChain.class);
doThrow(new BadCredentialsException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
filter.setPortResolver(new MockPortResolver(8080, 8443)); HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
/* requestCache.setPortResolver(new MockPortResolver(8080, 8443));
* Disabled the call to afterPropertiesSet as it requires filter.setRequestCache(requestCache);
* applicationContext to be injected before it is invoked. We do not filter.afterPropertiesSet();
* have this filter configured in IOC for this test hence no
* ApplicationContext
*/
// filter.afterPropertiesSet();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, chain); filter.doFilter(request, response, fc);
assertEquals("/mycontext/login.jsp", response.getRedirectedUrl()); assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
assertEquals("http://www.example.com:8080/mycontext/secure/page.html", getSavedRequestUrl(request)); assertEquals("http://www.example.com:8080/mycontext/secure/page.html", getSavedRequestUrl(request));
} }
@Test
public void testSavedRequestIsNotStoredForPostIfJustUseSaveRequestOnGetIsSet() throws Exception { public void testSavedRequestIsNotStoredForPostIfJustUseSaveRequestOnGetIsSet() throws Exception {
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setJustUseSavedRequestOnGet(true); HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
requestCache.setPortResolver(new MockPortResolver(8080, 8443));
requestCache.setJustUseSavedRequestOnGet(true);
filter.setRequestCache(requestCache);
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
filter.setPortResolver(new MockPortResolver(8080, 8443));
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
MockFilterChain chain = new MockFilterChain(false, true, false, false); FilterChain fc = mock(FilterChain.class);
doThrow(new BadCredentialsException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
request.setMethod("POST"); request.setMethod("POST");
filter.doFilter(request, new MockHttpServletResponse(), chain); filter.doFilter(request, new MockHttpServletResponse(), fc);
assertTrue(request.getSession().getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY) == null); assertTrue(request.getSession().getAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY) == null);
} }
@Test(expected=IllegalArgumentException.class)
public void testStartupDetectsMissingAuthenticationEntryPoint() throws Exception { public void testStartupDetectsMissingAuthenticationEntryPoint() throws Exception {
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setThrowableAnalyzer(mock(ThrowableAnalyzer.class));
try {
filter.afterPropertiesSet(); filter.afterPropertiesSet();
fail("Should have thrown IllegalArgumentException");
}
catch (IllegalArgumentException expected) {
assertEquals("authenticationEntryPoint must be specified", expected.getMessage());
}
} }
public void testStartupDetectsMissingPortResolver() throws Exception { @Test(expected=IllegalArgumentException.class)
public void testStartupDetectsMissingRequestCache() throws Exception {
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
filter.setPortResolver(null);
try { filter.setRequestCache(null);
filter.afterPropertiesSet();
fail("Should have thrown IllegalArgumentException");
}
catch (IllegalArgumentException expected) {
assertEquals("portResolver must be specified", expected.getMessage());
}
} }
public void testSuccessfulAccessGrant() throws Exception { public void testSuccessfulAccessGrant() throws Exception {
@ -235,39 +223,24 @@ public class ExceptionTranslationFilterTests extends TestCase {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setServletPath("/secure/page.html"); request.setServletPath("/secure/page.html");
// Setup the FilterChain to thrown no exceptions
MockFilterChain chain = new MockFilterChain(false, false, false, false);
// Test // Test
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, chain); filter.doFilter(request, response, mock(FilterChain.class));
}
public void testSuccessfulStartupAndShutdownDown() throws Exception {
ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.init(null);
filter.destroy();
assertTrue(true);
} }
@Test
public void testThrowIOException() throws Exception { public void testThrowIOException() throws Exception {
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
/* filter.afterPropertiesSet();
* Disabled the call to afterPropertiesSet as it requires FilterChain fc = mock(FilterChain.class);
* applicationContext to be injected before it is invoked. We do not doThrow(new IOException()).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
* have this filter configured in IOC for this test hence no
* ApplicationContext
*/
// filter.afterPropertiesSet();
try { try {
filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), new MockFilterChain(false, filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), fc);
false, false, true));
fail("Should have thrown IOException"); fail("Should have thrown IOException");
} }
catch (IOException e) { catch (IOException e) {
@ -275,20 +248,16 @@ public class ExceptionTranslationFilterTests extends TestCase {
} }
} }
@Test
public void testThrowServletException() throws Exception { public void testThrowServletException() throws Exception {
ExceptionTranslationFilter filter = new ExceptionTranslationFilter(); ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
filter.setAuthenticationEntryPoint(mockEntryPoint()); filter.setAuthenticationEntryPoint(mockEntryPoint());
/* filter.afterPropertiesSet();
* Disabled the call to afterPropertiesSet as it requires FilterChain fc = mock(FilterChain.class);
* applicationContext to be injected before it is invoked. We do not doThrow(new ServletException()).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
* have this filter configured in IOC for this test hence no
* ApplicationContext
*/
// filter.afterPropertiesSet();
try { try {
filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), new MockFilterChain(false, filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), fc);
false, true, false));
fail("Should have thrown ServletException"); fail("Should have thrown ServletException");
} }
catch (ServletException e) { catch (ServletException e) {
@ -304,42 +273,4 @@ public class ExceptionTranslationFilterTests extends TestCase {
} }
}; };
} }
// ~ Inner Classes =================================================================================================
private class MockFilterChain implements FilterChain {
private boolean throwAccessDenied;
private boolean throwAuthenticationFailure;
private boolean throwIOException;
private boolean throwServletException;
public MockFilterChain(boolean throwAccessDenied, boolean throwAuthenticationFailure,
boolean throwServletException, boolean throwIOException) {
this.throwAccessDenied = throwAccessDenied;
this.throwAuthenticationFailure = throwAuthenticationFailure;
this.throwServletException = throwServletException;
this.throwIOException = throwIOException;
}
public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
if (throwAccessDenied) {
throw new AccessDeniedException("As requested");
}
if (throwAuthenticationFailure) {
throw new BadCredentialsException("As requested");
}
if (throwServletException) {
throw new ServletException("As requested");
}
if (throwIOException) {
throw new IOException("As requested");
}
}
}
} }

View File

@ -1,4 +1,4 @@
package org.springframework.security.web.wrapper; package org.springframework.security.web.savedrequest;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -14,15 +14,13 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.web.PortResolverImpl; import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.savedrequest.FastHttpDateFormat; import org.springframework.security.web.savedrequest.FastHttpDateFormat;
import org.springframework.security.web.savedrequest.SavedRequest; import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper;
public class SavedRequestAwareWrapperTests { public class SavedRequestAwareWrapperTests {
private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) { private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
if (requestToSave != null) { SavedRequest saved = requestToSave == null ? null : new SavedRequest(requestToSave, new PortResolverImpl());
SavedRequest savedRequest = new SavedRequest(requestToSave, new PortResolverImpl()); return new SavedRequestAwareWrapper(saved, requestToWrap);
requestToWrap.getSession().setAttribute(SavedRequest.SPRING_SECURITY_SAVED_REQUEST_KEY, savedRequest);
}
return new SavedRequestAwareWrapper(requestToWrap, new PortResolverImpl(),"ROLE_");
} }
@Test @Test
@ -128,7 +126,7 @@ public class SavedRequestAwareWrapperTests {
@Test @Test
public void getParameterValuesReturnsNullIfParameterIsntSet() { public void getParameterValuesReturnsNullIfParameterIsntSet() {
MockHttpServletRequest wrappedRequest = new MockHttpServletRequest(); MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(wrappedRequest, new PortResolverImpl(), "ROLE_"); SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(null, wrappedRequest);
assertNull(wrapper.getParameterValues("action")); assertNull(wrapper.getParameterValues("action"));
assertNull(wrapper.getParameterMap().get("action")); assertNull(wrapper.getParameterMap().get("action"));
} }

View File

@ -25,8 +25,6 @@ import org.jmock.integration.junit4.JUnit4Mockery;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.wrapper.SecurityContextHolderAwareRequestFilter;
/** /**
@ -43,15 +41,13 @@ public class SecurityContextHolderAwareRequestFilterTests {
@Test @Test
public void expectedRequestWrapperClassIsUsed() throws Exception { public void expectedRequestWrapperClassIsUsed() throws Exception {
SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter(); SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter();
filter.setPortResolver(new PortResolverImpl());
filter.setWrapperClass(SavedRequestAwareWrapper.class);
filter.setRolePrefix("ROLE_"); filter.setRolePrefix("ROLE_");
filter.init(jmock.mock(FilterConfig.class)); filter.init(jmock.mock(FilterConfig.class));
final FilterChain filterChain = jmock.mock(FilterChain.class); final FilterChain filterChain = jmock.mock(FilterChain.class);
jmock.checking(new Expectations() {{ jmock.checking(new Expectations() {{
exactly(2).of(filterChain).doFilter( exactly(2).of(filterChain).doFilter(
with(aNonNull(SavedRequestAwareWrapper.class)), with(aNonNull(HttpServletResponse.class))); with(aNonNull(SecurityContextHolderAwareRequestWrapper.class)), with(aNonNull(HttpServletResponse.class)));
}}); }});
filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), filterChain); filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), filterChain);

View File

@ -23,9 +23,6 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.wrapper.SecurityContextHolderAwareRequestWrapper;
/** /**
* Tests {@link SecurityContextHolderAwareRequestWrapper}. * Tests {@link SecurityContextHolderAwareRequestWrapper}.
@ -34,17 +31,6 @@ import org.springframework.security.web.wrapper.SecurityContextHolderAwareReques
* @version $Id$ * @version $Id$
*/ */
public class SecurityContextHolderAwareRequestWrapperTests extends TestCase { public class SecurityContextHolderAwareRequestWrapperTests extends TestCase {
//~ Constructors ===================================================================================================
public SecurityContextHolderAwareRequestWrapperTests() {
}
public SecurityContextHolderAwareRequestWrapperTests(String arg0) {
super(arg0);
}
//~ Methods ========================================================================================================
protected void tearDown() throws Exception { protected void tearDown() throws Exception {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
@ -57,7 +43,7 @@ public class SecurityContextHolderAwareRequestWrapperTests extends TestCase {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setRequestURI("/"); request.setRequestURI("/");
SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, new PortResolverImpl(), ""); SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, "");
assertEquals("rod", wrapper.getRemoteUser()); assertEquals("rod", wrapper.getRemoteUser());
assertTrue(wrapper.isUserInRole("ROLE_FOO")); assertTrue(wrapper.isUserInRole("ROLE_FOO"));
@ -72,7 +58,7 @@ public class SecurityContextHolderAwareRequestWrapperTests extends TestCase {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setRequestURI("/"); request.setRequestURI("/");
SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, new PortResolverImpl(), "ROLE_"); SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, "ROLE_");
assertTrue(wrapper.isUserInRole("FOO")); assertTrue(wrapper.isUserInRole("FOO"));
} }
@ -85,7 +71,7 @@ public class SecurityContextHolderAwareRequestWrapperTests extends TestCase {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setRequestURI("/"); request.setRequestURI("/");
SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, new PortResolverImpl(), ""); SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, "");
assertEquals("rodAsUserDetails", wrapper.getRemoteUser()); assertEquals("rodAsUserDetails", wrapper.getRemoteUser());
assertFalse(wrapper.isUserInRole("ROLE_FOO")); assertFalse(wrapper.isUserInRole("ROLE_FOO"));
@ -101,7 +87,7 @@ public class SecurityContextHolderAwareRequestWrapperTests extends TestCase {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setRequestURI("/"); request.setRequestURI("/");
SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request,new PortResolverImpl(), ""); SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, "");
assertNull(wrapper.getRemoteUser()); assertNull(wrapper.getRemoteUser());
assertFalse(wrapper.isUserInRole("ROLE_ANY")); assertFalse(wrapper.isUserInRole("ROLE_ANY"));
assertNull(wrapper.getUserPrincipal()); assertNull(wrapper.getUserPrincipal());
@ -114,7 +100,7 @@ public class SecurityContextHolderAwareRequestWrapperTests extends TestCase {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setRequestURI("/"); request.setRequestURI("/");
SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, new PortResolverImpl(), ""); SecurityContextHolderAwareRequestWrapper wrapper = new SecurityContextHolderAwareRequestWrapper(request, "");
assertNull(wrapper.getRemoteUser()); assertNull(wrapper.getRemoteUser());
assertFalse(wrapper.isUserInRole("ROLE_HELLO")); // principal is null, so reject assertFalse(wrapper.isUserInRole("ROLE_HELLO")); // principal is null, so reject