diff --git a/core/src/main/java/org/springframework/security/firewall/RequestWrapper.java b/core/src/main/java/org/springframework/security/firewall/RequestWrapper.java index 352749a568..f490976ed0 100644 --- a/core/src/main/java/org/springframework/security/firewall/RequestWrapper.java +++ b/core/src/main/java/org/springframework/security/firewall/RequestWrapper.java @@ -1,6 +1,12 @@ package org.springframework.security.firewall; +import javax.servlet.RequestDispatcher; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; + +import java.io.IOException; import java.util.*; /** @@ -92,7 +98,43 @@ final class RequestWrapper extends FirewalledRequest { return stripPaths ? strippedServletPath : super.getServletPath(); } + public RequestDispatcher getRequestDispatcher(String path) { + return this.stripPaths ? new FirewalledRequestAwareRequestDispatcher(path) : super.getRequestDispatcher(path); + } + public void reset() { this.stripPaths = false; } + + /** + * Ensures {@link FirewalledRequest#reset()} is called prior to performing a forward. It then delegates work to the + * {@link RequestDispatcher} from the original {@link HttpServletRequest}. + * + * @author Rob Winch + */ + private class FirewalledRequestAwareRequestDispatcher implements RequestDispatcher { + private final String path; + + /** + * + * @param path the {@code path} that will be used to obtain the delegate {@link RequestDispatcher} from the + * original {@link HttpServletRequest}. + */ + public FirewalledRequestAwareRequestDispatcher(String path) { + this.path = path; + } + + public void forward(ServletRequest request, ServletResponse response) throws ServletException, IOException { + reset(); + getDelegateDispatcher().forward(request, response); + } + + public void include(ServletRequest request, ServletResponse response) throws ServletException, IOException { + getDelegateDispatcher().include(request, response); + } + + private RequestDispatcher getDelegateDispatcher() { + return RequestWrapper.super.getRequestDispatcher(path); + } + } } diff --git a/core/src/test/java/org/springframework/security/firewall/RequestWrapperTests.java b/core/src/test/java/org/springframework/security/firewall/RequestWrapperTests.java index d8d117bf4e..e3784f6dc4 100644 --- a/core/src/test/java/org/springframework/security/firewall/RequestWrapperTests.java +++ b/core/src/test/java/org/springframework/security/firewall/RequestWrapperTests.java @@ -1,13 +1,19 @@ package org.springframework.security.firewall; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.util.LinkedHashMap; +import java.util.Map; + +import javax.servlet.RequestDispatcher; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.junit.BeforeClass; import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; -import java.util.*; - /** * @author Luke Taylor */ @@ -59,4 +65,40 @@ public class RequestWrapperTests { } } + @Test + public void resetWhenForward() throws Exception { + String denormalizedPath = testPaths.keySet().iterator().next(); + String forwardPath = "/forward/path"; + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + RequestDispatcher mockDispatcher = mock(RequestDispatcher.class); + when(mockRequest.getServletPath()).thenReturn(""); + when(mockRequest.getPathInfo()).thenReturn(denormalizedPath); + when(mockRequest.getRequestDispatcher(forwardPath)).thenReturn(mockDispatcher); + + RequestWrapper wrapper = new RequestWrapper(mockRequest); + RequestDispatcher dispatcher = wrapper.getRequestDispatcher(forwardPath); + dispatcher.forward(mockRequest, mockResponse); + + verify(mockRequest).getRequestDispatcher(forwardPath); + verify(mockDispatcher).forward(mockRequest, mockResponse); + assertEquals(denormalizedPath,wrapper.getPathInfo()); + verify(mockRequest,times(2)).getPathInfo(); + // validate wrapper.getServletPath() delegates to the mock + wrapper.getServletPath(); + verify(mockRequest,times(2)).getServletPath(); + verifyNoMoreInteractions(mockRequest,mockResponse,mockDispatcher); + } + + @Test + public void requestDispatcherNotWrappedAfterReset() { + String path = "/forward/path"; + HttpServletRequest request = mock(HttpServletRequest.class); + RequestDispatcher dispatcher = mock(RequestDispatcher.class); + when(request.getRequestDispatcher(path)).thenReturn(dispatcher); + RequestWrapper wrapper = new RequestWrapper(request); + wrapper.reset(); + assertSame(dispatcher, wrapper.getRequestDispatcher(path)); + } + }