diff --git a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java index 773fe01c36..5640b7dc78 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java @@ -28,34 +28,40 @@ public class DefaultRedirectStrategy implements RedirectStrategy { * redirect is being performed to change to HTTPS, for example. */ public void sendRedirect(HttpServletRequest request, HttpServletResponse response, String url) throws IOException { - String finalUrl; - if (!url.startsWith("http://") && !url.startsWith("https://")) { - if (contextRelative) { - finalUrl = url; - } - else { - finalUrl = request.getContextPath() + url; - } - } - else if (contextRelative) { - // Calculate the relative URL from the fully qualifed URL, minus the protocol and base context. - int len = request.getContextPath().length(); - int index = url.indexOf(request.getContextPath()) + len; - finalUrl = url.substring(index); - - if (finalUrl.length() > 1 && finalUrl.charAt(0) == '/') { - finalUrl = finalUrl.substring(1); - } - } - else { - finalUrl = url; - } + String redirectUrl = calculateRedirectUrl(request.getContextPath(), url); + redirectUrl = response.encodeRedirectURL(redirectUrl); if (logger.isDebugEnabled()) { - logger.debug("Redirecting to '" + finalUrl + "'"); + logger.debug("Redirecting to '" + redirectUrl + "'"); } - response.sendRedirect(response.encodeRedirectURL(finalUrl)); + response.sendRedirect(redirectUrl); + } + + private String calculateRedirectUrl(String contextPath, String url) { + if (!url.startsWith("http://") && !url.startsWith("https://")) { + if (contextRelative) { + return url; + } else { + return contextPath + url; + } + } + + // Full URL, including http(s):// + + if (!contextRelative) { + return url; + } + + // Calculate the relative URL from the fully qualifed URL, minus the protocol and base context. + url = url.substring(url.indexOf("://") + 3); // strip off protocol + url = url.substring(url.indexOf(contextPath) + contextPath.length()); + + if (url.length() > 1 && url.charAt(0) == '/') { + url = url.substring(1); + } + + return url; } /** diff --git a/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java b/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java new file mode 100644 index 0000000000..b91edff2d0 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java @@ -0,0 +1,27 @@ +package org.springframework.security.web; + +import static org.junit.Assert.*; + +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +/** + * + * @author Luke Taylor + * @since 3.0 + */ +public class DefaultRedirectStrategyTests { + @Test + public void contextRelativeUrlWithContextNameInHostnameIsHandledCorrectly() throws Exception { + DefaultRedirectStrategy rds = new DefaultRedirectStrategy(); + rds.setContextRelative(true); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setContextPath("/context"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + rds.sendRedirect(request, response, "http://context.blah.com/context/remainder"); + + assertEquals("remainder", response.getRedirectedUrl()); + } +}