diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 4fea8b7acb..e9cbd86f1f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -183,16 +184,9 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt } private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) { - int port = request.getServerPort(); - if (("http".equals(request.getScheme()) && port == 80) || ("https".equals(request.getScheme()) && port == 443)) { - port = -1; // Removes the port in UriComponentsBuilder - } - - String baseUrl = UriComponentsBuilder.newInstance() - .scheme(request.getScheme()) - .host(request.getServerName()) - .port(port) - .path(request.getContextPath()) + String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + .replaceQuery(null) + .replacePath(request.getContextPath()) .build() .toUriString(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index 51e87b6ccd..b482608f36 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -34,8 +34,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.context.SecurityContextRepository; +import org.springframework.security.web.util.UrlUtils; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -192,7 +194,10 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce String code = request.getParameter(OAuth2ParameterNames.CODE); String errorCode = request.getParameter(OAuth2ParameterNames.ERROR); String state = request.getParameter(OAuth2ParameterNames.STATE); - String redirectUri = request.getRequestURL().toString(); + String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + .replaceQuery(null) + .build() + .toUriString(); if (StringUtils.hasText(code)) { return OAuth2AuthorizationResponse.success(code) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 91034fc0ea..5d225b02f1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -43,9 +43,12 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.util.UrlUtils; +import org.springframework.web.util.UriComponentsBuilder; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; @@ -64,7 +67,7 @@ import static org.powermock.api.mockito.PowerMockito.verifyPrivate; * @author Joe Grandja */ @PowerMockIgnore("javax.security.*") -@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class}) +@PrepareForTest({OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class}) @RunWith(PowerMockRunner.class) public class OAuth2LoginAuthenticationFilterTests { private ClientRegistration registration1; @@ -322,15 +325,133 @@ public class OAuth2LoginAuthenticationFilterTests { } } + // gh-5756 + @Test + public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort() throws Exception { + String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("http"); + request.setServerName("example.com"); + request.setServerPort(80); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration2); + this.setUpAuthenticationResult(this.registration2); + + this.filter.doFilter(request, response, filterChain); + + ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); + verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture()); + + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue(); + OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse(); + + String expectedRedirectUri = "http://example.com/login/oauth2/code/registration-2"; + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri); + assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); + } + + // gh-5756 + @Test + public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort() throws Exception { + String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("https"); + request.setServerName("example.com"); + request.setServerPort(443); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration2); + this.setUpAuthenticationResult(this.registration2); + + this.filter.doFilter(request, response, filterChain); + + ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); + verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture()); + + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue(); + OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse(); + + String expectedRedirectUri = "https://example.com/login/oauth2/code/registration-2"; + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri); + assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); + } + + // gh-5756 + @Test + public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort() throws Exception { + String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("https"); + request.setServerName("example.com"); + request.setServerPort(9090); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.setUpAuthorizationRequest(request, response, this.registration2); + this.setUpAuthenticationResult(this.registration2); + + this.filter.doFilter(request, response, filterChain); + + ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); + verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture()); + + OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue(); + OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest(); + OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse(); + + String expectedRedirectUri = "https://example.com:9090/login/oauth2/code/registration-2"; + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri); + assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); + } + private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, ClientRegistration registration) { - OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class); Map additionalParameters = new HashMap<>(); additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); - when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) + .clientId(registration.getClientId()) + .redirectUri(expandRedirectUri(request, registration)) + .scopes(registration.getScopes()) + .state("state") + .additionalParameters(additionalParameters) + .build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); } + private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) { + String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) + .replaceQuery(null) + .replacePath(request.getContextPath()) + .build() + .toUriString(); + + Map uriVariables = new HashMap<>(); + uriVariables.put("baseUrl", baseUrl); + uriVariables.put("registrationId", clientRegistration.getRegistrationId()); + + return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate()) + .buildAndExpand(uriVariables) + .toUriString(); + } + private void setUpAuthenticationResult(ClientRegistration registration) { OAuth2User user = mock(OAuth2User.class); when(user.getName()).thenReturn(this.principalName1);