OAuth2AuthorizationCodeGrantFilter matches on query parameters
Fixes gh-7963
This commit is contained in:
parent
d3490b0f87
commit
3c86239b39
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -41,6 +41,7 @@ import org.springframework.util.Assert;
|
|||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
import org.springframework.web.util.UriComponents;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
|
@ -48,6 +49,11 @@ import javax.servlet.ServletException;
|
|||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
|
||||
|
@ -132,24 +138,39 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
|
|||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
||||
throws ServletException, IOException {
|
||||
|
||||
if (this.shouldProcessAuthorizationResponse(request)) {
|
||||
this.processAuthorizationResponse(request, response);
|
||||
if (matchesAuthorizationResponse(request)) {
|
||||
processAuthorizationResponse(request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
filterChain.doFilter(request, response);
|
||||
}
|
||||
|
||||
private boolean shouldProcessAuthorizationResponse(HttpServletRequest request) {
|
||||
private boolean matchesAuthorizationResponse(HttpServletRequest request) {
|
||||
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
|
||||
if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
|
||||
return false;
|
||||
}
|
||||
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||
if (authorizationRequest == null) {
|
||||
return false;
|
||||
}
|
||||
String requestUrl = UrlUtils.buildFullRequestUrl(request.getScheme(), request.getServerName(),
|
||||
request.getServerPort(), request.getRequestURI(), null);
|
||||
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
|
||||
if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
|
||||
OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
|
||||
|
||||
// Compare redirect_uri
|
||||
UriComponents requestUri = UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request)).build();
|
||||
UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()).build();
|
||||
Set<Map.Entry<String, List<String>>> requestUriParameters = new LinkedHashSet<>(requestUri.getQueryParams().entrySet());
|
||||
Set<Map.Entry<String, List<String>>> redirectUriParameters = new LinkedHashSet<>(redirectUri.getQueryParams().entrySet());
|
||||
// Remove the additional request parameters (if any) from the authorization response (request)
|
||||
// before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any)
|
||||
requestUriParameters.retainAll(redirectUriParameters);
|
||||
|
||||
if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) &&
|
||||
Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) &&
|
||||
Objects.equals(requestUri.getHost(), redirectUri.getHost()) &&
|
||||
Objects.equals(requestUri.getPort(), redirectUri.getPort()) &&
|
||||
Objects.equals(requestUri.getPath(), redirectUri.getPath()) &&
|
||||
Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -165,10 +186,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
|
|||
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
|
||||
|
||||
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
|
||||
String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
|
||||
.replaceQuery(null)
|
||||
.build()
|
||||
.toUriString();
|
||||
String redirectUri = UrlUtils.buildFullRequestUrl(request);
|
||||
OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri);
|
||||
|
||||
OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken(
|
||||
|
@ -183,7 +201,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
|
|||
} catch (OAuth2AuthorizationException ex) {
|
||||
OAuth2Error error = ex.getError();
|
||||
UriComponentsBuilder uriBuilder = UriComponentsBuilder
|
||||
.fromUriString(authorizationResponse.getRedirectUri())
|
||||
.fromUriString(authorizationRequest.getRedirectUri())
|
||||
.queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
|
||||
if (!StringUtils.isEmpty(error.getDescription())) {
|
||||
uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
|
||||
|
@ -206,7 +224,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
|
|||
|
||||
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response);
|
||||
|
||||
String redirectUrl = authorizationResponse.getRedirectUri();
|
||||
String redirectUrl = authorizationRequest.getRedirectUri();
|
||||
SavedRequest savedRequest = this.requestCache.getRequest(request, response);
|
||||
if (savedRequest != null) {
|
||||
redirectUrl = savedRequest.getRedirectUrl();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,17 +15,9 @@
|
|||
*/
|
||||
package org.springframework.security.oauth2.client.web;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.servlet.http.HttpSession;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||
|
@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
|
|||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
|
||||
import org.springframework.security.web.savedrequest.RequestCache;
|
||||
import org.springframework.security.web.util.UrlUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.servlet.http.HttpSession;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
|
||||
import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken;
|
||||
|
@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
// NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
|
||||
|
||||
HttpServletResponse response = mock(HttpServletResponse.class);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception {
|
||||
String requestUri = "/path";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
HttpServletResponse response = mock(HttpServletResponse.class);
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRequestRedirectUriThenNotProcessed() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
HttpServletResponse response = mock(HttpServletResponse.class);
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
request.setRequestURI(requestUri + "-no-match");
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/path");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatchThenNotProcessed() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri);
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
authorizationResponse.setRequestURI(requestUri + "-no-match");
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
// gh-7963
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() throws Exception {
|
||||
// 1) redirect_uri with query parameters
|
||||
String requestUri = "/callback/client-1";
|
||||
Map<String, String> parameters = new LinkedHashMap<>();
|
||||
parameters.put("param1", "value1");
|
||||
parameters.put("param2", "value2");
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
// 2) redirect_uri with query parameters AND authorization response additional parameters
|
||||
Map<String, String> additionalParameters = new LinkedHashMap<>();
|
||||
additionalParameters.put("auth-param1", "value1");
|
||||
additionalParameters.put("auth-param2", "value2");
|
||||
response = new MockHttpServletResponse();
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
verifyNoInteractions(filterChain);
|
||||
}
|
||||
|
||||
// gh-7963
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
Map<String, String> parameters = new LinkedHashMap<>();
|
||||
parameters.put("param1", "value1");
|
||||
parameters.put("param2", "value2");
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
// 1) Parameter value
|
||||
Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
|
||||
parametersNotMatch.put("param2", "value8");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(
|
||||
createAuthorizationRequest(requestUri, parametersNotMatch));
|
||||
authorizationResponse.setSession(authorizationRequest.getSession());
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
// 2) Parameter order
|
||||
parametersNotMatch = new LinkedHashMap<>();
|
||||
parametersNotMatch.put("param2", "value2");
|
||||
parametersNotMatch.put("param1", "value1");
|
||||
authorizationResponse = createAuthorizationResponse(
|
||||
createAuthorizationRequest(requestUri, parametersNotMatch));
|
||||
authorizationResponse.setSession(authorizationRequest.getSession());
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
|
||||
// 3) Parameter missing
|
||||
parametersNotMatch = new LinkedHashMap<>(parameters);
|
||||
parametersNotMatch.remove("param2");
|
||||
authorizationResponse = createAuthorizationResponse(
|
||||
createAuthorizationRequest(requestUri, parametersNotMatch));
|
||||
authorizationResponse.setSession(authorizationRequest.getSession());
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved() throws Exception {
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
|
||||
assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT);
|
||||
when(this.authenticationManager.authenticate(any(Authentication.class)))
|
||||
.thenThrow(new OAuth2AuthorizationException(error));
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService() throws Exception {
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
|
||||
this.registration1.getRegistrationId(), this.principalName1);
|
||||
|
@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception {
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception {
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirectedToSavedRequest() throws Exception {
|
||||
public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
|
||||
String requestUri = "/saved-request";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
RequestCache requestCache = new HttpSessionRequestCache();
|
||||
requestCache.saveRequest(request, response);
|
||||
|
||||
requestUri = "/callback/client-1";
|
||||
request.setRequestURI(requestUri);
|
||||
request.setRequestURI("/callback/client-1");
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
|
||||
|
@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
|
||||
public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
|
||||
AnonymousAuthenticationToken anonymousPrincipal =
|
||||
new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
|
||||
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
||||
securityContext.setAuthentication(anonymousPrincipal);
|
||||
SecurityContextHolder.setContext(securityContext);
|
||||
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
||||
this.registration1.getRegistrationId(), anonymousPrincipal, request);
|
||||
this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse);
|
||||
assertThat(authorizedClient).isNotNull();
|
||||
|
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
|
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
|
||||
assertThat(authorizedClient.getAccessToken()).isNotNull();
|
||||
|
||||
HttpSession session = request.getSession(false);
|
||||
HttpSession session = authorizationResponse.getSession(false);
|
||||
assertThat(session).isNotNull();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
|
||||
public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
|
||||
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
||||
SecurityContextHolder.setContext(securityContext); // null Authentication
|
||||
|
||||
String requestUri = "/callback/client-1";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
|
||||
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
||||
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1);
|
||||
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
this.filter.doFilter(authorizationResponse, response, filterChain);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
||||
this.registration1.getRegistrationId(), null, request);
|
||||
this.registration1.getRegistrationId(), null, authorizationResponse);
|
||||
assertThat(authorizedClient).isNotNull();
|
||||
|
||||
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
|
||||
assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
|
||||
assertThat(authorizedClient.getAccessToken()).isNotNull();
|
||||
|
||||
HttpSession session = request.getSession(false);
|
||||
HttpSession session = authorizationResponse.getSession(false);
|
||||
assertThat(session).isNotNull();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|||
assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
|
||||
}
|
||||
|
||||
private static MockHttpServletRequest createAuthorizationRequest(String requestUri) {
|
||||
return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
|
||||
}
|
||||
|
||||
private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
if (!CollectionUtils.isEmpty(parameters)) {
|
||||
parameters.forEach(request::addParameter);
|
||||
request.setQueryString(
|
||||
parameters.entrySet().stream()
|
||||
.map(e -> e.getKey() + "=" + e.getValue())
|
||||
.collect(Collectors.joining("&")));
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest) {
|
||||
return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
|
||||
}
|
||||
|
||||
private static MockHttpServletRequest createAuthorizationResponse(
|
||||
MockHttpServletRequest authorizationRequest, Map<String, String> additionalParameters) {
|
||||
MockHttpServletRequest authorizationResponse = new MockHttpServletRequest(
|
||||
authorizationRequest.getMethod(), authorizationRequest.getRequestURI());
|
||||
authorizationResponse.setServletPath(authorizationRequest.getRequestURI());
|
||||
authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter);
|
||||
authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state");
|
||||
additionalParameters.forEach(authorizationResponse::addParameter);
|
||||
authorizationResponse.setQueryString(
|
||||
authorizationResponse.getParameterMap().entrySet().stream()
|
||||
.map(e -> e.getKey() + "=" + e.getValue()[0])
|
||||
.collect(Collectors.joining("&")));
|
||||
authorizationResponse.setSession(authorizationRequest.getSession());
|
||||
return authorizationResponse;
|
||||
}
|
||||
|
||||
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
||||
ClientRegistration registration) {
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
||||
Map<String, Object> attributes = new HashMap<>();
|
||||
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
||||
OAuth2AuthorizationRequest authorizationRequest = request()
|
||||
.additionalParameters(additionalParameters)
|
||||
.redirectUri(request.getRequestURL().toString()).build();
|
||||
.attributes(attributes)
|
||||
.redirectUri(UrlUtils.buildFullRequestUrl(request)).build();
|
||||
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue