diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index 04a88be9c3..4cb47f575c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -178,9 +178,10 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce .toUriString(); OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri); + Object authenticationDetails = this.authenticationDetailsSource.buildDetails(request); OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); - authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); + authenticationRequest.setDetails(authenticationDetails); OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this.getAuthenticationManager().authenticate(authenticationRequest); @@ -189,6 +190,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce authenticationResult.getPrincipal(), authenticationResult.getAuthorities(), authenticationResult.getClientRegistration().getRegistrationId()); + oauth2Authentication.setDetails(authenticationDetails); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( authenticationResult.getClientRegistration(), diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index a7c25871f9..8bc1551527 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -27,6 +27,7 @@ import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -50,6 +51,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.security.web.util.UrlUtils; import org.springframework.web.util.UriComponentsBuilder; @@ -79,6 +81,7 @@ public class OAuth2LoginAuthenticationFilterTests { private AuthorizationRequestRepository authorizationRequestRepository; private AuthenticationFailureHandler failureHandler; private AuthenticationManager authenticationManager; + private AuthenticationDetailsSource authenticationDetailsSource; private OAuth2LoginAuthenticationToken loginAuthentication; private OAuth2LoginAuthenticationFilter filter; @@ -93,11 +96,13 @@ public class OAuth2LoginAuthenticationFilterTests { this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); this.failureHandler = mock(AuthenticationFailureHandler.class); this.authenticationManager = mock(AuthenticationManager.class); + this.authenticationDetailsSource = mock(AuthenticationDetailsSource.class); this.filter = spy(new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, this.authorizedClientRepository, OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI)); this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); this.filter.setAuthenticationFailureHandler(this.failureHandler); this.filter.setAuthenticationManager(this.authenticationManager); + this.filter.setAuthenticationDetailsSource(this.authenticationDetailsSource); } @Test @@ -400,6 +405,29 @@ public class OAuth2LoginAuthenticationFilterTests { assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri); } + // gh-6866 + @Test + public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationResult() throws Exception { + String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); + String state = "state"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, state); + + WebAuthenticationDetails webAuthenticationDetails = mock(WebAuthenticationDetails.class); + when(authenticationDetailsSource.buildDetails(any())).thenReturn(webAuthenticationDetails); + + MockHttpServletResponse response = new MockHttpServletResponse(); + + this.setUpAuthorizationRequest(request, response, this.registration2, state); + this.setUpAuthenticationResult(this.registration2); + + Authentication result = this.filter.attemptAuthentication(request, response); + + assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails); + } + private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, ClientRegistration registration, String state) { Map attributes = new HashMap<>();