From 3c86239b39e5b000eab70a94dd91c61d578ef998 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 11 Nov 2019 06:36:02 -0500 Subject: [PATCH] OAuth2AuthorizationCodeGrantFilter matches on query parameters Fixes gh-7963 --- .../OAuth2AuthorizationCodeGrantFilter.java | 48 ++- ...uth2AuthorizationCodeGrantFilterTests.java | 307 +++++++++++------- 2 files changed, 221 insertions(+), 134 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index 3eda6df7c6..4f8aaefbaf 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,6 +41,7 @@ import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import javax.servlet.FilterChain; @@ -48,6 +49,11 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; /** * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, @@ -132,24 +138,39 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (this.shouldProcessAuthorizationResponse(request)) { - this.processAuthorizationResponse(request, response); + if (matchesAuthorizationResponse(request)) { + processAuthorizationResponse(request, response); return; } filterChain.doFilter(request, response); } - private boolean shouldProcessAuthorizationResponse(HttpServletRequest request) { + private boolean matchesAuthorizationResponse(HttpServletRequest request) { + MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); + if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) { + return false; + } OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request); if (authorizationRequest == null) { return false; } - String requestUrl = UrlUtils.buildFullRequestUrl(request.getScheme(), request.getServerName(), - request.getServerPort(), request.getRequestURI(), null); - MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); - if (requestUrl.equals(authorizationRequest.getRedirectUri()) && - OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) { + + // Compare redirect_uri + UriComponents requestUri = UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request)).build(); + UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()).build(); + Set>> requestUriParameters = new LinkedHashSet<>(requestUri.getQueryParams().entrySet()); + Set>> redirectUriParameters = new LinkedHashSet<>(redirectUri.getQueryParams().entrySet()); + // Remove the additional request parameters (if any) from the authorization response (request) + // before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any) + requestUriParameters.retainAll(redirectUriParameters); + + if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) && + Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) && + Objects.equals(requestUri.getHost(), redirectUri.getHost()) && + Objects.equals(requestUri.getPort(), redirectUri.getPort()) && + Objects.equals(requestUri.getPath(), redirectUri.getPath()) && + Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) { return true; } return false; @@ -165,10 +186,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); - String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) - .replaceQuery(null) - .build() - .toUriString(); + String redirectUri = UrlUtils.buildFullRequestUrl(request); OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri); OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken( @@ -183,7 +201,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { } catch (OAuth2AuthorizationException ex) { OAuth2Error error = ex.getError(); UriComponentsBuilder uriBuilder = UriComponentsBuilder - .fromUriString(authorizationResponse.getRedirectUri()) + .fromUriString(authorizationRequest.getRedirectUri()) .queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode()); if (!StringUtils.isEmpty(error.getDescription())) { uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription()); @@ -206,7 +224,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response); - String redirectUrl = authorizationResponse.getRedirectUri(); + String redirectUrl = authorizationRequest.getRedirectUri(); SavedRequest savedRequest = this.requestCache.getRequest(request, response); if (savedRequest != null) { redirectUrl = savedRequest.getRedirectUrl(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index aaacad6b21..39b3011f03 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,17 +15,9 @@ */ package org.springframework.security.oauth2.client.web; -import java.util.HashMap; -import java.util.Map; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; - import org.junit.After; import org.junit.Before; import org.junit.Test; - import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; +import org.springframework.security.web.util.UrlUtils; +import org.springframework.util.CollectionUtils; + +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken; @@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); // NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter. - - HttpServletResponse response = mock(HttpServletResponse.class); + MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests { @Test public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception { - String requestUri = "/path"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - - HttpServletResponse response = mock(HttpServletResponse.class); - FilterChain filterChain = mock(FilterChain.class); - - this.filter.doFilter(request, response, filterChain); - - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - - @Test - public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRequestRedirectUriThenNotProcessed() throws Exception { - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - - HttpServletResponse response = mock(HttpServletResponse.class); - FilterChain filterChain = mock(FilterChain.class); - - this.setUpAuthorizationRequest(request, response, this.registration1); - request.setRequestURI(requestUri + "-no-match"); - - this.filter.doFilter(request, response, filterChain); - - verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); - } - - @Test - public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception { - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/path"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration1); + this.filter.doFilter(authorizationResponse, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatchThenNotProcessed() throws Exception { + String requestUri = "/callback/client-1"; + MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + authorizationResponse.setRequestURI(requestUri + "-no-match"); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(authorizationResponse, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + // gh-7963 + @Test + public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() throws Exception { + // 1) redirect_uri with query parameters + String requestUri = "/callback/client-1"; + Map parameters = new LinkedHashMap<>(); + parameters.put("param1", "value1"); + parameters.put("param2", "value2"); + MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters); + MockHttpServletResponse response = new MockHttpServletResponse(); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + FilterChain filterChain = mock(FilterChain.class); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + this.filter.doFilter(authorizationResponse, response, filterChain); + verifyNoInteractions(filterChain); + + // 2) redirect_uri with query parameters AND authorization response additional parameters + Map additionalParameters = new LinkedHashMap<>(); + additionalParameters.put("auth-param1", "value1"); + additionalParameters.put("auth-param2", "value2"); + response = new MockHttpServletResponse(); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters); + this.filter.doFilter(authorizationResponse, response, filterChain); + verifyNoInteractions(filterChain); + } + + // gh-7963 + @Test + public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed() throws Exception { + String requestUri = "/callback/client-1"; + Map parameters = new LinkedHashMap<>(); + parameters.put("param1", "value1"); + parameters.put("param2", "value2"); + MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters); + MockHttpServletResponse response = new MockHttpServletResponse(); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + FilterChain filterChain = mock(FilterChain.class); + + // 1) Parameter value + Map parametersNotMatch = new LinkedHashMap<>(parameters); + parametersNotMatch.put("param2", "value8"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse( + createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse.setSession(authorizationRequest.getSession()); + this.filter.doFilter(authorizationResponse, response, filterChain); + verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + // 2) Parameter order + parametersNotMatch = new LinkedHashMap<>(); + parametersNotMatch.put("param2", "value2"); + parametersNotMatch.put("param1", "value1"); + authorizationResponse = createAuthorizationResponse( + createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse.setSession(authorizationRequest.getSession()); + this.filter.doFilter(authorizationResponse, response, filterChain); + verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + + // 3) Parameter missing + parametersNotMatch = new LinkedHashMap<>(parameters); + parametersNotMatch.remove("param2"); + authorizationResponse = createAuthorizationResponse( + createAuthorizationRequest(requestUri, parametersNotMatch)); + authorizationResponse.setSession(authorizationRequest.getSession()); + this.filter.doFilter(authorizationResponse, response, filterChain); + verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved() throws Exception { + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); + this.filter.doFilter(authorizationResponse, response, filterChain); - assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull(); + assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull(); } @Test public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException() throws Exception { - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); - this.setUpAuthorizationRequest(request, response, this.registration1); OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT); when(this.authenticationManager.authenticate(any(Authentication.class))) .thenThrow(new OAuth2AuthorizationException(error)); - this.filter.doFilter(request, response, filterChain); + this.filter.doFilter(authorizationResponse, response, filterChain); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant"); } @Test - public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception { - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - + public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService() throws Exception { + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); + this.filter.doFilter(authorizationResponse, response, filterChain); OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient( this.registration1.getRegistrationId(), this.principalName1); @@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests { } @Test - public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception { - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - + public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception { + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); + this.filter.doFilter(authorizationResponse, response, filterChain); assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1"); } @Test - public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirectedToSavedRequest() throws Exception { + public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception { String requestUri = "/saved-request"; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); RequestCache requestCache = new HttpSessionRequestCache(); requestCache.saveRequest(request, response); - - requestUri = "/callback/client-1"; - request.setRequestURI(requestUri); + request.setRequestURI("/callback/client-1"); request.addParameter(OAuth2ParameterNames.CODE, "code"); request.addParameter(OAuth2ParameterNames.STATE, "state"); - FilterChain filterChain = mock(FilterChain.class); - this.setUpAuthorizationRequest(request, response, this.registration1); this.setUpAuthenticationResult(this.registration1); @@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests { } @Test - public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception { + public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception { AnonymousAuthenticationToken anonymousPrincipal = new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); securityContext.setAuthentication(anonymousPrincipal); SecurityContextHolder.setContext(securityContext); - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); + this.filter.doFilter(authorizationResponse, response, filterChain); OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registration1.getRegistrationId(), anonymousPrincipal, request); + this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse); assertThat(authorizedClient).isNotNull(); - assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName()); assertThat(authorizedClient.getAccessToken()).isNotNull(); - HttpSession session = request.getSession(false); + HttpSession session = authorizationResponse.getSession(false); assertThat(session).isNotNull(); @SuppressWarnings("unchecked") @@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests { } @Test - public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception { + public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception { SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); SecurityContextHolder.setContext(securityContext); // null Authentication - String requestUri = "/callback/client-1"; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); - request.addParameter(OAuth2ParameterNames.CODE, "code"); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - - this.setUpAuthorizationRequest(request, response, this.registration1); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); this.setUpAuthenticationResult(this.registration1); - this.filter.doFilter(request, response, filterChain); + this.filter.doFilter(authorizationResponse, response, filterChain); OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - this.registration1.getRegistrationId(), null, request); + this.registration1.getRegistrationId(), null, authorizationResponse); assertThat(authorizedClient).isNotNull(); - assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1); assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); assertThat(authorizedClient.getAccessToken()).isNotNull(); - HttpSession session = request.getSession(false); + HttpSession session = authorizationResponse.getSession(false); assertThat(session).isNotNull(); @SuppressWarnings("unchecked") @@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests { assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient); } + private static MockHttpServletRequest createAuthorizationRequest(String requestUri) { + return createAuthorizationRequest(requestUri, new LinkedHashMap<>()); + } + + private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map parameters) { + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + if (!CollectionUtils.isEmpty(parameters)) { + parameters.forEach(request::addParameter); + request.setQueryString( + parameters.entrySet().stream() + .map(e -> e.getKey() + "=" + e.getValue()) + .collect(Collectors.joining("&"))); + } + return request; + } + + private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest) { + return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>()); + } + + private static MockHttpServletRequest createAuthorizationResponse( + MockHttpServletRequest authorizationRequest, Map additionalParameters) { + MockHttpServletRequest authorizationResponse = new MockHttpServletRequest( + authorizationRequest.getMethod(), authorizationRequest.getRequestURI()); + authorizationResponse.setServletPath(authorizationRequest.getRequestURI()); + authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter); + authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code"); + authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state"); + additionalParameters.forEach(authorizationResponse::addParameter); + authorizationResponse.setQueryString( + authorizationResponse.getParameterMap().entrySet().stream() + .map(e -> e.getKey() + "=" + e.getValue()[0]) + .collect(Collectors.joining("&"))); + authorizationResponse.setSession(authorizationRequest.getSession()); + return authorizationResponse; + } + private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, ClientRegistration registration) { - Map additionalParameters = new HashMap<>(); - additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); OAuth2AuthorizationRequest authorizationRequest = request() - .additionalParameters(additionalParameters) - .redirectUri(request.getRequestURL().toString()).build(); + .attributes(attributes) + .redirectUri(UrlUtils.buildFullRequestUrl(request)).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); }