SEC-1012: Minor improvements to SecurityContextHolderAwareRequestFilter and conversion to use jmock for test.

This commit is contained in:
Luke Taylor 2008-12-06 17:31:53 +00:00
parent 953a4ab9ea
commit c3d216e7bb
2 changed files with 47 additions and 72 deletions

View File

@ -32,13 +32,15 @@ import org.springframework.util.ReflectionUtils;
/** /**
* A <code>Filter</code> which populates the <code>ServletRequest</code> with a new request wrapper.<p>Several * A <code>Filter</code> which populates the <code>ServletRequest</code> with a new request wrapper.
* request wrappers are included with the framework. The simplest version is {@link * Several request wrappers are included with the framework. The simplest version is {@link
* SecurityContextHolderAwareRequestWrapper}. A more complex and powerful request wrapper is {@link * SecurityContextHolderAwareRequestWrapper}. A more complex and powerful request wrapper is
* org.springframework.security.wrapper.SavedRequestAwareWrapper}. The latter is also the default.</p> * {@link SavedRequestAwareWrapper}. The latter is also the default.
* <p>To modify the wrapper used, call {@link #setWrapperClass(Class)}.</p> * <p>
* <p>Any request wrapper configured for instantiation by this class must provide a public constructor that * To modify the wrapper used, call {@link #setWrapperClass(Class)}.
* accepts two arguments, being a <code>HttpServletRequest</code> and a <code>PortResolver</code>.</p> * <p>
* Any request wrapper configured for instantiation by this class must provide a public constructor that
* accepts two arguments, being a <code>HttpServletRequest</code> and a <code>PortResolver</code>.
* *
* @author Orlando Garcia Carmona * @author Orlando Garcia Carmona
* @author Ben Alex * @author Ben Alex
@ -47,8 +49,8 @@ import org.springframework.util.ReflectionUtils;
public class SecurityContextHolderAwareRequestFilter extends SpringSecurityFilter { public class SecurityContextHolderAwareRequestFilter extends SpringSecurityFilter {
//~ Instance fields ================================================================================================ //~ Instance fields ================================================================================================
private Class wrapperClass = SavedRequestAwareWrapper.class; private Class<? extends HttpServletRequest> wrapperClass = SavedRequestAwareWrapper.class;
private Constructor constructor; private Constructor<? extends HttpServletRequest> constructor;
private PortResolver portResolver = new PortResolverImpl(); private PortResolver portResolver = new PortResolverImpl();
private String rolePrefix; private String rolePrefix;
@ -59,6 +61,7 @@ public class SecurityContextHolderAwareRequestFilter extends SpringSecurityFilte
this.portResolver = portResolver; this.portResolver = portResolver;
} }
@SuppressWarnings("unchecked")
public void setWrapperClass(Class wrapperClass) { public void setWrapperClass(Class wrapperClass) {
Assert.notNull(wrapperClass, "WrapperClass required"); Assert.notNull(wrapperClass, "WrapperClass required");
Assert.isTrue(HttpServletRequest.class.isAssignableFrom(wrapperClass), "Wrapper must be a HttpServletRequest"); Assert.isTrue(HttpServletRequest.class.isAssignableFrom(wrapperClass), "Wrapper must be a HttpServletRequest");
@ -72,17 +75,12 @@ public class SecurityContextHolderAwareRequestFilter extends SpringSecurityFilte
protected void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { protected void doFilterHttp(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException {
if (!wrapperClass.isAssignableFrom(request.getClass())) { if (!wrapperClass.isAssignableFrom(request.getClass())) {
if (constructor == null) {
try { try {
constructor = wrapperClass.getConstructor( if (constructor == null) {
new Class[] {HttpServletRequest.class, PortResolver.class, String.class}); constructor = wrapperClass.getConstructor(HttpServletRequest.class, PortResolver.class, String.class);
} catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex);
}
} }
try { request = constructor.newInstance(request, portResolver, rolePrefix);
request = (HttpServletRequest) constructor.newInstance(new Object[] {request, portResolver, rolePrefix});
} catch (Exception ex) { } catch (Exception ex) {
ReflectionUtils.handleReflectionException(ex); ReflectionUtils.handleReflectionException(ex);
} }

View File

@ -15,19 +15,17 @@
package org.springframework.security.wrapper; package org.springframework.security.wrapper;
import junit.framework.TestCase; import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import org.springframework.security.MockFilterConfig; import javax.servlet.http.HttpServletResponse;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.jmock.integration.junit4.JUnit4Mockery;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.util.PortResolverImpl;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
/** /**
@ -36,51 +34,30 @@ import javax.servlet.ServletResponse;
* @author Ben Alex * @author Ben Alex
* @version $Id$ * @version $Id$
*/ */
public class SecurityContextHolderAwareRequestFilterTests extends TestCase { public class SecurityContextHolderAwareRequestFilterTests {
//~ Constructors =================================================================================================== Mockery jmock = new JUnit4Mockery();
public SecurityContextHolderAwareRequestFilterTests() {
}
public SecurityContextHolderAwareRequestFilterTests(String arg0) {
super(arg0);
}
//~ Methods ======================================================================================================== //~ Methods ========================================================================================================
public final void setUp() throws Exception { @Test
super.setUp(); public void expectedRequestWrapperClassIsUsed() throws Exception {
}
public void testCorrectOperation() throws Exception {
SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter(); SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter();
filter.init(new MockFilterConfig()); filter.setPortResolver(new PortResolverImpl());
filter.doFilter(new MockHttpServletRequest(null, null), new MockHttpServletResponse(), filter.setWrapperClass(SavedRequestAwareWrapper.class);
new MockFilterChain(SavedRequestAwareWrapper.class)); filter.setRolePrefix("ROLE_");
filter.init(jmock.mock(FilterConfig.class));
final FilterChain filterChain = jmock.mock(FilterChain.class);
jmock.checking(new Expectations() {{
exactly(2).of(filterChain).doFilter(
with(aNonNull(SavedRequestAwareWrapper.class)), with(aNonNull(HttpServletResponse.class)));
}});
filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), filterChain);
// Now re-execute the filter, ensuring our replacement wrapper is still used // Now re-execute the filter, ensuring our replacement wrapper is still used
filter.doFilter(new MockHttpServletRequest(null, null), new MockHttpServletResponse(), filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), filterChain);
new MockFilterChain(SavedRequestAwareWrapper.class));
filter.destroy(); filter.destroy();
} }
//~ Inner Classes ==================================================================================================
private class MockFilterChain implements FilterChain {
private Class expectedServletRequest;
public MockFilterChain(Class expectedServletRequest) {
this.expectedServletRequest = expectedServletRequest;
}
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
if (request.getClass().isAssignableFrom(expectedServletRequest)) {
assertTrue(true);
} else {
fail("Expected class to be of type " + expectedServletRequest + " but was: " + request.getClass());
}
}
}
} }