diff --git a/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java index 95cc2db912..817feb6967 100644 --- a/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java @@ -18,6 +18,7 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.web.util.WebUtils; /** * A {@code SecurityContextRepository} implementation which stores the security context in the {@code HttpSession} @@ -105,7 +106,10 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo } public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) { - SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = (SaveContextOnUpdateOrErrorResponseWrapper)response; + SaveContextOnUpdateOrErrorResponseWrapper responseWrapper = WebUtils.getNativeResponse(response, SaveContextOnUpdateOrErrorResponseWrapper.class); + if(responseWrapper == null) { + throw new IllegalStateException("Cannot invoke saveContext on response " + response + ". You must use the HttpRequestResponseHolder.response after invoking loadContext"); + } // saveContext() might already be called by the response wrapper // if something in the chain called sendError() or sendRedirect(). This ensures we only call it // once per request. diff --git a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java index 6a7b65e359..4f9c7e3610 100644 --- a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java @@ -29,9 +29,7 @@ import static org.springframework.security.web.context.HttpSessionSecurityContex import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; +import javax.servlet.http.*; import org.junit.After; import org.junit.Test; @@ -495,4 +493,33 @@ public class HttpSessionSecurityContextRepositoryTests { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); repo.setTrustResolver(null); } -} + + // SEC-2578 + @Test + public void traverseWrappedRequests() { + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); + SecurityContext context = repo.loadContext(holder); + assertNull(request.getSession(false)); + // Simulate authentication during the request + context.setAuthentication(testToken); + + repo.saveContext(context, new HttpServletRequestWrapper(holder.getRequest()), new HttpServletResponseWrapper(holder.getResponse())); + + assertNotNull(request.getSession(false)); + assertEquals(context, request.getSession().getAttribute(SPRING_SECURITY_CONTEXT_KEY)); + } + + @Test(expected = IllegalStateException.class) + public void failsWithStandardResponse() { + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + SecurityContext context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(testToken); + + repo.saveContext(context,request,response); + } +} \ No newline at end of file