diff --git a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java index 42055203f4..2c0d8ae0f3 100644 --- a/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java +++ b/web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java @@ -7,6 +7,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.security.web.util.UrlUtils; /** * Simple implementation of RedirectStrategy which is the default used throughout the framework. @@ -15,6 +16,7 @@ import org.apache.commons.logging.LogFactory; * @since 3.0 */ public class DefaultRedirectStrategy implements RedirectStrategy { + protected final Log logger = LogFactory.getLog(getClass()); private boolean contextRelative; @@ -38,7 +40,7 @@ public class DefaultRedirectStrategy implements RedirectStrategy { } private String calculateRedirectUrl(String contextPath, String url) { - if (!url.startsWith("http://") && !url.startsWith("https://")) { + if (!UrlUtils.isAbsoluteUrl(url)) { if (contextRelative) { return url; } else { @@ -52,8 +54,8 @@ public class DefaultRedirectStrategy implements RedirectStrategy { return url; } - // Calculate the relative URL from the fully qualifed URL, minus the protocol and base context. - url = url.substring(url.indexOf("://") + 3); // strip off protocol + // Calculate the relative URL from the fully qualified URL, minus the scheme and base context. + url = url.substring(url.indexOf("://") + 3); // strip off scheme url = url.substring(url.indexOf(contextPath) + contextPath.length()); if (url.length() > 1 && url.charAt(0) == '/') { diff --git a/web/src/main/java/org/springframework/security/web/util/UrlUtils.java b/web/src/main/java/org/springframework/security/web/util/UrlUtils.java index 88d4f14d2c..fb679e0f21 100644 --- a/web/src/main/java/org/springframework/security/web/util/UrlUtils.java +++ b/web/src/main/java/org/springframework/security/web/util/UrlUtils.java @@ -15,6 +15,8 @@ package org.springframework.security.web.util; +import java.util.regex.Pattern; + import javax.servlet.http.HttpServletRequest; @@ -96,7 +98,7 @@ public final class UrlUtils { * Obtains the web application-specific fragment of the URL. */ private static String buildRequestUrl(String servletPath, String requestURI, String contextPath, String pathInfo, - String queryString) { + String queryString) { StringBuilder url = new StringBuilder(); @@ -117,9 +119,18 @@ public final class UrlUtils { } /** - * Returns true if the supplied URL starts with a "/" or "http". + * Returns true if the supplied URL starts with a "/" or is absolute. */ public static boolean isValidRedirectUrl(String url) { - return url != null && url.startsWith("/") || url.toLowerCase().startsWith("http"); + return url != null && url.startsWith("/") || isAbsoluteUrl(url); + } + + /** + * Decides if a URL is absolute based on whether it contains a valid scheme name, as defined in RFC 1738. + */ + public static boolean isAbsoluteUrl(String url) { + final Pattern ABSOLUTE_URL = Pattern.compile("\\A[a-z.+-]+://.*", Pattern.CASE_INSENSITIVE); + + return ABSOLUTE_URL.matcher(url).matches(); } } diff --git a/web/src/test/java/org/springframework/security/web/util/UrlUtilsTests.java b/web/src/test/java/org/springframework/security/web/util/UrlUtilsTests.java new file mode 100644 index 0000000000..4d521bc99e --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/util/UrlUtilsTests.java @@ -0,0 +1,22 @@ +package org.springframework.security.web.util; + +import static org.junit.Assert.*; + +import org.junit.Test; + +/** + * + * @author Luke Taylor + */ +public class UrlUtilsTests { + + @Test + public void absoluteUrlsAreMatchedAsAbsolute() throws Exception { + assertTrue(UrlUtils.isAbsoluteUrl("http://something/")); + assertTrue(UrlUtils.isAbsoluteUrl("HTTP://something/")); + assertTrue(UrlUtils.isAbsoluteUrl("https://something/")); + assertTrue(UrlUtils.isAbsoluteUrl("a://something/")); + assertTrue(UrlUtils.isAbsoluteUrl("zz+zz.zz-zz://something/")); + } + +}