OAuth2AuthorizationCodeGrantFilter matches on query parameters

Fixes gh-7963
This commit is contained in:
Joe Grandja 2019-11-11 06:36:02 -05:00
parent d3490b0f87
commit 3c86239b39
2 changed files with 221 additions and 134 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -41,6 +41,7 @@ import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
@ -48,6 +49,11 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
@ -132,24 +138,39 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
if (this.shouldProcessAuthorizationResponse(request)) {
this.processAuthorizationResponse(request, response);
if (matchesAuthorizationResponse(request)) {
processAuthorizationResponse(request, response);
return;
}
filterChain.doFilter(request, response);
}
private boolean shouldProcessAuthorizationResponse(HttpServletRequest request) {
private boolean matchesAuthorizationResponse(HttpServletRequest request) {
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
return false;
}
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request);
if (authorizationRequest == null) {
return false;
}
String requestUrl = UrlUtils.buildFullRequestUrl(request.getScheme(), request.getServerName(),
request.getServerPort(), request.getRequestURI(), null);
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
// Compare redirect_uri
UriComponents requestUri = UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request)).build();
UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()).build();
Set<Map.Entry<String, List<String>>> requestUriParameters = new LinkedHashSet<>(requestUri.getQueryParams().entrySet());
Set<Map.Entry<String, List<String>>> redirectUriParameters = new LinkedHashSet<>(redirectUri.getQueryParams().entrySet());
// Remove the additional request parameters (if any) from the authorization response (request)
// before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any)
requestUriParameters.retainAll(redirectUriParameters);
if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) &&
Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) &&
Objects.equals(requestUri.getHost(), redirectUri.getHost()) &&
Objects.equals(requestUri.getPort(), redirectUri.getPort()) &&
Objects.equals(requestUri.getPath(), redirectUri.getPath()) &&
Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) {
return true;
}
return false;
@ -165,10 +186,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replaceQuery(null)
.build()
.toUriString();
String redirectUri = UrlUtils.buildFullRequestUrl(request);
OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri);
OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken(
@ -183,7 +201,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
} catch (OAuth2AuthorizationException ex) {
OAuth2Error error = ex.getError();
UriComponentsBuilder uriBuilder = UriComponentsBuilder
.fromUriString(authorizationResponse.getRedirectUri())
.fromUriString(authorizationRequest.getRedirectUri())
.queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
if (!StringUtils.isEmpty(error.getDescription())) {
uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
@ -206,7 +224,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response);
String redirectUrl = authorizationResponse.getRedirectUri();
String redirectUrl = authorizationRequest.getRedirectUri();
SavedRequest savedRequest = this.requestCache.getRequest(request, response);
if (savedRequest != null) {
redirectUrl = savedRequest.getRedirectUrl();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,17 +15,9 @@
*/
package org.springframework.security.oauth2.client.web;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.CollectionUtils;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken;
@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
// NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
HttpServletResponse response = mock(HttpServletResponse.class);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@Test
public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRequestRedirectUriThenNotProcessed() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
request.setRequestURI(requestUri + "-no-match");
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/path");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatchThenNotProcessed() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri);
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
authorizationResponse.setRequestURI(requestUri + "-no-match");
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
// gh-7963
@Test
public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() throws Exception {
// 1) redirect_uri with query parameters
String requestUri = "/callback/client-1";
Map<String, String> parameters = new LinkedHashMap<>();
parameters.put("param1", "value1");
parameters.put("param2", "value2");
MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
FilterChain filterChain = mock(FilterChain.class);
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
this.filter.doFilter(authorizationResponse, response, filterChain);
verifyNoInteractions(filterChain);
// 2) redirect_uri with query parameters AND authorization response additional parameters
Map<String, String> additionalParameters = new LinkedHashMap<>();
additionalParameters.put("auth-param1", "value1");
additionalParameters.put("auth-param2", "value2");
response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
this.filter.doFilter(authorizationResponse, response, filterChain);
verifyNoInteractions(filterChain);
}
// gh-7963
@Test
public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed() throws Exception {
String requestUri = "/callback/client-1";
Map<String, String> parameters = new LinkedHashMap<>();
parameters.put("param1", "value1");
parameters.put("param2", "value2");
MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
FilterChain filterChain = mock(FilterChain.class);
// 1) Parameter value
Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
parametersNotMatch.put("param2", "value8");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(
createAuthorizationRequest(requestUri, parametersNotMatch));
authorizationResponse.setSession(authorizationRequest.getSession());
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
// 2) Parameter order
parametersNotMatch = new LinkedHashMap<>();
parametersNotMatch.put("param2", "value2");
parametersNotMatch.put("param1", "value1");
authorizationResponse = createAuthorizationResponse(
createAuthorizationRequest(requestUri, parametersNotMatch));
authorizationResponse.setSession(authorizationRequest.getSession());
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
// 3) Parameter missing
parametersNotMatch = new LinkedHashMap<>(parameters);
parametersNotMatch.remove("param2");
authorizationResponse = createAuthorizationResponse(
createAuthorizationRequest(requestUri, parametersNotMatch));
authorizationResponse.setSession(authorizationRequest.getSession());
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
this.filter.doFilter(authorizationResponse, response, filterChain);
assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull();
}
@Test
public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthorizationRequest(request, response, this.registration1);
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT);
when(this.authenticationManager.authenticate(any(Authentication.class)))
.thenThrow(new OAuth2AuthorizationException(error));
this.filter.doFilter(request, response, filterChain);
this.filter.doFilter(authorizationResponse, response, filterChain);
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant");
}
@Test
public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
this.filter.doFilter(authorizationResponse, response, filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
this.registration1.getRegistrationId(), this.principalName1);
@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception {
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
this.filter.doFilter(authorizationResponse, response, filterChain);
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}
@Test
public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirectedToSavedRequest() throws Exception {
public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
String requestUri = "/saved-request";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
RequestCache requestCache = new HttpSessionRequestCache();
requestCache.saveRequest(request, response);
requestUri = "/callback/client-1";
request.setRequestURI(requestUri);
request.setRequestURI("/callback/client-1");
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
AnonymousAuthenticationToken anonymousPrincipal =
new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(anonymousPrincipal);
SecurityContextHolder.setContext(securityContext);
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
this.filter.doFilter(authorizationResponse, response, filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registration1.getRegistrationId(), anonymousPrincipal, request);
this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse);
assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
assertThat(authorizedClient.getAccessToken()).isNotNull();
HttpSession session = request.getSession(false);
HttpSession session = authorizationResponse.getSession(false);
assertThat(session).isNotNull();
@SuppressWarnings("unchecked")
@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
SecurityContextHolder.setContext(securityContext); // null Authentication
String requestUri = "/callback/client-1";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(request, response, this.registration1);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
this.filter.doFilter(request, response, filterChain);
this.filter.doFilter(authorizationResponse, response, filterChain);
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
this.registration1.getRegistrationId(), null, request);
this.registration1.getRegistrationId(), null, authorizationResponse);
assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
assertThat(authorizedClient.getAccessToken()).isNotNull();
HttpSession session = request.getSession(false);
HttpSession session = authorizationResponse.getSession(false);
assertThat(session).isNotNull();
@SuppressWarnings("unchecked")
@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
}
private static MockHttpServletRequest createAuthorizationRequest(String requestUri) {
return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
}
private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
if (!CollectionUtils.isEmpty(parameters)) {
parameters.forEach(request::addParameter);
request.setQueryString(
parameters.entrySet().stream()
.map(e -> e.getKey() + "=" + e.getValue())
.collect(Collectors.joining("&")));
}
return request;
}
private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest) {
return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
}
private static MockHttpServletRequest createAuthorizationResponse(
MockHttpServletRequest authorizationRequest, Map<String, String> additionalParameters) {
MockHttpServletRequest authorizationResponse = new MockHttpServletRequest(
authorizationRequest.getMethod(), authorizationRequest.getRequestURI());
authorizationResponse.setServletPath(authorizationRequest.getRequestURI());
authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter);
authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code");
authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state");
additionalParameters.forEach(authorizationResponse::addParameter);
authorizationResponse.setQueryString(
authorizationResponse.getParameterMap().entrySet().stream()
.map(e -> e.getKey() + "=" + e.getValue()[0])
.collect(Collectors.joining("&")));
authorizationResponse.setSession(authorizationRequest.getSession());
return authorizationResponse;
}
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration) {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = request()
.additionalParameters(additionalParameters)
.redirectUri(request.getRequestURL().toString()).build();
.attributes(attributes)
.redirectUri(UrlUtils.buildFullRequestUrl(request)).build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}