mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-07-23 10:43:30 +00:00
Ensure consistent matching of redirect_uri
Fixes gh-5756
This commit is contained in:
parent
566fb939ca
commit
9a49795abc
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2002-2017 the original author or authors.
|
* Copyright 2002-2018 the original author or authors.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
@ -25,6 +25,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
|
|||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
import org.springframework.security.web.DefaultRedirectStrategy;
|
import org.springframework.security.web.DefaultRedirectStrategy;
|
||||||
import org.springframework.security.web.RedirectStrategy;
|
import org.springframework.security.web.RedirectStrategy;
|
||||||
|
import org.springframework.security.web.util.UrlUtils;
|
||||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.web.filter.OncePerRequestFilter;
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
@ -183,16 +184,9 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
|
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
|
||||||
int port = request.getServerPort();
|
String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
|
||||||
if (("http".equals(request.getScheme()) && port == 80) || ("https".equals(request.getScheme()) && port == 443)) {
|
.replaceQuery(null)
|
||||||
port = -1; // Removes the port in UriComponentsBuilder
|
.replacePath(request.getContextPath())
|
||||||
}
|
|
||||||
|
|
||||||
String baseUrl = UriComponentsBuilder.newInstance()
|
|
||||||
.scheme(request.getScheme())
|
|
||||||
.host(request.getServerName())
|
|
||||||
.port(port)
|
|
||||||
.path(request.getContextPath())
|
|
||||||
.build()
|
.build()
|
||||||
.toUriString();
|
.toUriString();
|
||||||
|
|
||||||
|
@ -34,8 +34,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
|
|||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
|
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
|
||||||
import org.springframework.security.web.context.SecurityContextRepository;
|
import org.springframework.security.web.context.SecurityContextRepository;
|
||||||
|
import org.springframework.security.web.util.UrlUtils;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
import javax.servlet.ServletException;
|
import javax.servlet.ServletException;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
@ -192,7 +194,10 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
|
|||||||
String code = request.getParameter(OAuth2ParameterNames.CODE);
|
String code = request.getParameter(OAuth2ParameterNames.CODE);
|
||||||
String errorCode = request.getParameter(OAuth2ParameterNames.ERROR);
|
String errorCode = request.getParameter(OAuth2ParameterNames.ERROR);
|
||||||
String state = request.getParameter(OAuth2ParameterNames.STATE);
|
String state = request.getParameter(OAuth2ParameterNames.STATE);
|
||||||
String redirectUri = request.getRequestURL().toString();
|
String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
|
||||||
|
.replaceQuery(null)
|
||||||
|
.build()
|
||||||
|
.toUriString();
|
||||||
|
|
||||||
if (StringUtils.hasText(code)) {
|
if (StringUtils.hasText(code)) {
|
||||||
return OAuth2AuthorizationResponse.success(code)
|
return OAuth2AuthorizationResponse.success(code)
|
||||||
|
@ -43,9 +43,12 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|||||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||||
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
||||||
|
import org.springframework.security.web.util.UrlUtils;
|
||||||
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
import javax.servlet.FilterChain;
|
import javax.servlet.FilterChain;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
@ -64,7 +67,7 @@ import static org.powermock.api.mockito.PowerMockito.verifyPrivate;
|
|||||||
* @author Joe Grandja
|
* @author Joe Grandja
|
||||||
*/
|
*/
|
||||||
@PowerMockIgnore("javax.security.*")
|
@PowerMockIgnore("javax.security.*")
|
||||||
@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
|
@PrepareForTest({OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
|
||||||
@RunWith(PowerMockRunner.class)
|
@RunWith(PowerMockRunner.class)
|
||||||
public class OAuth2LoginAuthenticationFilterTests {
|
public class OAuth2LoginAuthenticationFilterTests {
|
||||||
private ClientRegistration registration1;
|
private ClientRegistration registration1;
|
||||||
@ -322,15 +325,133 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gh-5756
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort() throws Exception {
|
||||||
|
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
|
request.setScheme("http");
|
||||||
|
request.setServerName("example.com");
|
||||||
|
request.setServerPort(80);
|
||||||
|
request.setServletPath(requestUri);
|
||||||
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||||
|
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
this.setUpAuthorizationRequest(request, response, this.registration2);
|
||||||
|
this.setUpAuthenticationResult(this.registration2);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
||||||
|
verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
|
||||||
|
|
||||||
|
OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
|
||||||
|
OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
|
||||||
|
|
||||||
|
String expectedRedirectUri = "http://example.com/login/oauth2/code/registration-2";
|
||||||
|
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
||||||
|
assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
||||||
|
}
|
||||||
|
|
||||||
|
// gh-5756
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort() throws Exception {
|
||||||
|
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
|
request.setScheme("https");
|
||||||
|
request.setServerName("example.com");
|
||||||
|
request.setServerPort(443);
|
||||||
|
request.setServletPath(requestUri);
|
||||||
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||||
|
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
this.setUpAuthorizationRequest(request, response, this.registration2);
|
||||||
|
this.setUpAuthenticationResult(this.registration2);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
||||||
|
verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
|
||||||
|
|
||||||
|
OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
|
||||||
|
OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
|
||||||
|
|
||||||
|
String expectedRedirectUri = "https://example.com/login/oauth2/code/registration-2";
|
||||||
|
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
||||||
|
assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
||||||
|
}
|
||||||
|
|
||||||
|
// gh-5756
|
||||||
|
@Test
|
||||||
|
public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort() throws Exception {
|
||||||
|
String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
|
request.setScheme("https");
|
||||||
|
request.setServerName("example.com");
|
||||||
|
request.setServerPort(9090);
|
||||||
|
request.setServletPath(requestUri);
|
||||||
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||||
|
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||||
|
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
|
||||||
|
this.setUpAuthorizationRequest(request, response, this.registration2);
|
||||||
|
this.setUpAuthenticationResult(this.registration2);
|
||||||
|
|
||||||
|
this.filter.doFilter(request, response, filterChain);
|
||||||
|
|
||||||
|
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
||||||
|
verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
|
||||||
|
|
||||||
|
OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
|
||||||
|
OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
|
||||||
|
OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
|
||||||
|
|
||||||
|
String expectedRedirectUri = "https://example.com:9090/login/oauth2/code/registration-2";
|
||||||
|
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
||||||
|
assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
||||||
|
}
|
||||||
|
|
||||||
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
||||||
ClientRegistration registration) {
|
ClientRegistration registration) {
|
||||||
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
|
||||||
Map<String, Object> additionalParameters = new HashMap<>();
|
Map<String, Object> additionalParameters = new HashMap<>();
|
||||||
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
||||||
when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
|
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
|
||||||
|
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
|
||||||
|
.clientId(registration.getClientId())
|
||||||
|
.redirectUri(expandRedirectUri(request, registration))
|
||||||
|
.scopes(registration.getScopes())
|
||||||
|
.state("state")
|
||||||
|
.additionalParameters(additionalParameters)
|
||||||
|
.build();
|
||||||
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
|
||||||
|
String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
|
||||||
|
.replaceQuery(null)
|
||||||
|
.replacePath(request.getContextPath())
|
||||||
|
.build()
|
||||||
|
.toUriString();
|
||||||
|
|
||||||
|
Map<String, String> uriVariables = new HashMap<>();
|
||||||
|
uriVariables.put("baseUrl", baseUrl);
|
||||||
|
uriVariables.put("registrationId", clientRegistration.getRegistrationId());
|
||||||
|
|
||||||
|
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
|
||||||
|
.buildAndExpand(uriVariables)
|
||||||
|
.toUriString();
|
||||||
|
}
|
||||||
|
|
||||||
private void setUpAuthenticationResult(ClientRegistration registration) {
|
private void setUpAuthenticationResult(ClientRegistration registration) {
|
||||||
OAuth2User user = mock(OAuth2User.class);
|
OAuth2User user = mock(OAuth2User.class);
|
||||||
when(user.getName()).thenReturn(this.principalName1);
|
when(user.getName()).thenReturn(this.principalName1);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user