From f843da194296f3c4aa773be81f5951172265be89 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 1 Aug 2018 15:49:34 -0500 Subject: [PATCH] Add OAuth2LoginAuthenticationWebFilter This is necessary so that the saving of the authorized client occurs outside of the ReactiveAuthenticationManager. It will allow for saving with the ServerWebExchange when ReactiveOAuth2AuthorizedClientRepository is added. Issue: gh-5621 --- .../config/web/server/ServerHttpSecurity.java | 75 +++++++----- ...th2LoginReactiveAuthenticationManager.java | 23 +--- ...tionCodeReactiveAuthenticationManager.java | 26 +---- .../OAuth2LoginAuthenticationWebFilter.java | 70 +++++++++++ ...ginReactiveAuthenticationManagerTests.java | 22 +--- ...odeReactiveAuthenticationManagerTests.java | 23 +--- ...uth2LoginAuthenticationWebFilterTests.java | 110 ++++++++++++++++++ .../AuthenticationWebFilter.java | 2 +- 8 files changed, 243 insertions(+), 108 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index c899a5e4a7..c48c13270f 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -15,20 +15,6 @@ */ package org.springframework.security.config.web.server; -import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; - -import java.io.IOException; -import java.io.PrintWriter; -import java.io.StringWriter; -import java.security.interfaces.RSAPublicKey; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.core.Ordered; @@ -55,6 +41,9 @@ import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; import org.springframework.security.oauth2.client.web.server.ServerOAuth2LoginAuthenticationTokenConverter; +import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager; @@ -118,9 +107,22 @@ import org.springframework.web.cors.reactive.DefaultCorsProcessor; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; - import reactor.core.publisher.Mono; +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.security.interfaces.RSAPublicKey; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; + /** * A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux. @@ -445,7 +447,7 @@ public class ServerHttpSecurity { public class OAuth2LoginSpec { private ReactiveClientRegistrationRepository clientRegistrationRepository; - private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; public OAuth2LoginSpec clientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) { this.clientRegistrationRepository = clientRegistrationRepository; @@ -453,7 +455,12 @@ public class ServerHttpSecurity { } public OAuth2LoginSpec authorizedClientService(ReactiveOAuth2AuthorizedClientService authorizedClientService) { - this.authorizedClientService = authorizedClientService; + this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(authorizedClientService); + return this; + } + + public OAuth2LoginSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this.authorizedClientRepository = authorizedClientRepository; return this; } @@ -468,22 +475,21 @@ public class ServerHttpSecurity { protected void configure(ServerHttpSecurity http) { ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); - ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository); WebClientReactiveAuthorizationCodeTokenResponseClient client = new WebClientReactiveAuthorizationCodeTokenResponseClient(); ReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); - ReactiveAuthenticationManager manager = new OAuth2LoginReactiveAuthenticationManager(client, userService, - authorizedClientService); + ReactiveAuthenticationManager manager = new OAuth2LoginReactiveAuthenticationManager(client, userService); boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent( "org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); if (oidcAuthenticationProviderEnabled) { - OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, new OidcReactiveOAuth2UserService(), authorizedClientService); + OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, new OidcReactiveOAuth2UserService()); manager = new DelegatingReactiveAuthenticationManager(oidc, manager); } - AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(manager); + AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository); authenticationFilter.setRequiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}")); authenticationFilter.setServerAuthenticationConverter(new ServerOAuth2LoginAuthenticationTokenConverter(clientRegistrationRepository)); @@ -532,14 +538,27 @@ public class ServerHttpSecurity { return this.clientRegistrationRepository; } + private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() { + ServerOAuth2AuthorizedClientRepository result = this.authorizedClientRepository; + if (result == null) { + result = getBeanOrNull(ServerOAuth2AuthorizedClientRepository.class); + } + if (result == null) { + ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); + if (authorizedClientService != null) { + result = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( + authorizedClientService); + } + } + return result; + } + private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { - if (this.authorizedClientService == null) { - this.authorizedClientService = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); + ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); + if (service == null) { + service = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); } - if (this.authorizedClientService == null) { - this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); - } - return this.authorizedClientService; + return service; } private OAuth2LoginSpec() {} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java index eb2f161c77..3ff33b5ecf 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java @@ -66,20 +66,15 @@ public class OAuth2LoginReactiveAuthenticationManager implements private final ReactiveOAuth2UserService userService; - private final ReactiveOAuth2AuthorizedClientService authorizedClientService; - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); public OAuth2LoginReactiveAuthenticationManager( ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient, - ReactiveOAuth2UserService userService, - ReactiveOAuth2AuthorizedClientService authorizedClientService) { + ReactiveOAuth2UserService userService) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(userService, "userService cannot be null"); - Assert.notNull(authorizedClientService, "authorizedClientService"); this.accessTokenResponseClient = accessTokenResponseClient; this.userService = userService; - this.authorizedClientService = authorizedClientService; } @Override @@ -108,13 +103,13 @@ public class OAuth2LoginReactiveAuthenticationManager implements }); } - private Mono authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { + private Mono authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); Map additionalParameters = accessTokenResponse.getAdditionalParameters(); OAuth2UserRequest userRequest = new OAuth2UserRequest( authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters); return this.userService.loadUser(userRequest) - .flatMap(oauth2User -> { + .map(oauth2User -> { Collection mappedAuthorities = this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); @@ -125,17 +120,7 @@ public class OAuth2LoginReactiveAuthenticationManager implements mappedAuthorities, accessToken, accessTokenResponse.getRefreshToken()); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - authenticationResult.getClientRegistration(), - authenticationResult.getName(), - authenticationResult.getAccessToken(), - authenticationResult.getRefreshToken()); - OAuth2AuthenticationToken result = new OAuth2AuthenticationToken( - authenticationResult.getPrincipal(), - authenticationResult.getAuthorities(), - authenticationResult.getClientRegistration().getRegistrationId()); - return this.authorizedClientService.saveAuthorizedClient(authorizedClient, authenticationResult) - .thenReturn(result); + return authenticationResult; }); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index 8f9e4bcbba..c3a373ded9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -19,9 +19,6 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; @@ -85,22 +82,17 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements private final ReactiveOAuth2UserService userService; - private final ReactiveOAuth2AuthorizedClientService authorizedClientService; - private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); private Function decoderFactory = new DefaultDecoderFactory(); public OidcAuthorizationCodeReactiveAuthenticationManager( ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient, - ReactiveOAuth2UserService userService, - ReactiveOAuth2AuthorizedClientService authorizedClientService) { + ReactiveOAuth2UserService userService) { Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(userService, "userService cannot be null"); - Assert.notNull(authorizedClientService, "authorizedClientService"); this.accessTokenResponseClient = accessTokenResponseClient; this.userService = userService; - this.authorizedClientService = authorizedClientService; } @Override @@ -157,7 +149,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements this.decoderFactory = decoderFactory; } - private Mono authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { + private Mono authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); Map additionalParameters = accessTokenResponse.getAdditionalParameters(); @@ -173,26 +165,16 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements return createOidcToken(clientRegistration, accessTokenResponse) .map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters)) .flatMap(this.userService::loadUser) - .flatMap(oauth2User -> { + .map(oauth2User -> { Collection mappedAuthorities = this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); - OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( + return new OAuth2LoginAuthenticationToken( authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange(), oauth2User, mappedAuthorities, accessToken); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - authenticationResult.getClientRegistration(), - authenticationResult.getName(), - authenticationResult.getAccessToken()); - OAuth2AuthenticationToken result = new OAuth2AuthenticationToken( - authenticationResult.getPrincipal(), - authenticationResult.getAuthorities(), - authenticationResult.getClientRegistration().getRegistrationId()); - return this.authorizedClientService.saveAuthorizedClient(authorizedClient, authenticationResult) - .thenReturn(result); }); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java new file mode 100644 index 0000000000..a55d0dfe6d --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilter.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server.authentication; + +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.security.web.server.authentication.AuthenticationWebFilter; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +/** + * A specialized {@link AuthenticationWebFilter} that converts from an {@link OAuth2LoginAuthenticationToken} to an + * {@link OAuth2AuthenticationToken} and saves the {@link OAuth2AuthorizedClient} + * + * @author Rob Winch + * @since 5.1 + */ +public class OAuth2LoginAuthenticationWebFilter extends AuthenticationWebFilter { + + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + /** + * Creates an instance + * + * @param authenticationManager the authentication manager to use + * @param authorizedClientRepository + */ + public OAuth2LoginAuthenticationWebFilter( + ReactiveAuthenticationManager authenticationManager, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + super(authenticationManager); + Assert.notNull(authorizedClientRepository, "authorizedClientService cannot be null"); + this.authorizedClientRepository = authorizedClientRepository; + } + + @Override + protected Mono onAuthenticationSuccess(Authentication authentication, + WebFilterExchange webFilterExchange) { + OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) authentication; + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + authenticationResult.getClientRegistration(), + authenticationResult.getName(), + authenticationResult.getAccessToken(), + authenticationResult.getRefreshToken()); + OAuth2AuthenticationToken result = new OAuth2AuthenticationToken( + authenticationResult.getPrincipal(), + authenticationResult.getAuthorities(), + authenticationResult.getClientRegistration().getRegistrationId()); + return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authenticationResult, webFilterExchange.getExchange()) + .then(super.onAuthenticationSuccess(result, webFilterExchange)); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java index 943aec2023..da2707b678 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java @@ -42,8 +42,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -81,32 +79,20 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { @Before public void setup() { - this.manager = new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService); - when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + this.manager = new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService); } @Test public void constructorWhenNullAccessTokenResponseClientThenIllegalArgumentException() { this.accessTokenResponseClient = null; - assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService)) + assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) .isInstanceOf(IllegalArgumentException.class); } @Test public void constructorWhenNullUserServiceThenIllegalArgumentException() { this.userService = null; - assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void constructorWhenNullAuthorizedClientServiceThenIllegalArgumentException() { - this.authorizedClientService = null; - assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService)) + assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) .isInstanceOf(IllegalArgumentException.class); } @@ -164,7 +150,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests { DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user"); when(this.userService.loadUser(any())).thenReturn(Mono.just(user)); - OAuth2AuthenticationToken result = (OAuth2AuthenticationToken) this.manager.authenticate(loginToken()).block(); + OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block(); assertThat(result.getPrincipal()).isEqualTo(user); assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java index c28e729109..452f729094 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java @@ -24,7 +24,6 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; @@ -72,9 +71,6 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { @Mock private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; - @Mock - private ReactiveOAuth2AuthorizedClientService authorizedClientService; - @Mock private ReactiveJwtDecoder jwtDecoder; @@ -92,33 +88,20 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests { @Before public void setup() { - this.manager = new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService); - when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn( - Mono.empty()); + this.manager = new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService); } @Test public void constructorWhenNullAccessTokenResponseClientThenIllegalArgumentException() { this.accessTokenResponseClient = null; - assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService)) + assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) .isInstanceOf(IllegalArgumentException.class); } @Test public void constructorWhenNullUserServiceThenIllegalArgumentException() { this.userService = null; - assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void constructorWhenNullAuthorizedClientServiceThenIllegalArgumentException() { - this.authorizedClientService = null; - assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService, - this.authorizedClientService)) + assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService)) .isInstanceOf(IllegalArgumentException.class); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java new file mode 100644 index 0000000000..0933f1076e --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/authentication/OAuth2LoginAuthenticationWebFilterTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server.authentication; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.ReactiveAuthenticationManager; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; +import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.web.server.handler.DefaultWebFilterChain; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + * @since 5.1 + */ +@RunWith(MockitoJUnitRunner.class) +public class OAuth2LoginAuthenticationWebFilterTests { + @Mock + private ReactiveAuthenticationManager authenticationManager; + @Mock + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private OAuth2LoginAuthenticationWebFilter filter; + private WebFilterExchange webFilterExchange; + + + private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration(); + + + private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse + .success("code") + .state("state"); + + @Before + public void setup() { + this.filter = new OAuth2LoginAuthenticationWebFilter(this.authenticationManager, this.authorizedClientRepository); + this.webFilterExchange = new WebFilterExchange(MockServerWebExchange.from(MockServerHttpRequest.get("/")), new DefaultWebFilterChain(exchange -> exchange.getResponse().setComplete())); + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) + .thenReturn(Mono.empty()); + } + + @Test + public void onAuthenticationSuccessWhenOAuth2LoginAuthenticationTokenThenSavesAuthorizedClient() { + this.filter.onAuthenticationSuccess(loginToken(), this.webFilterExchange).block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); + } + + private OAuth2LoginAuthenticationToken loginToken() { + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "token", + Instant.now(), + Instant.now().plus(Duration.ofDays(1)), + Collections.singleton("user")); + DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections + .singletonMap("user", "rob"), "user"); + ClientRegistration clientRegistration = this.registration.build(); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest + .authorizationCode() + .state("state") + .clientId(clientRegistration.getClientId()) + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .redirectUri(clientRegistration.getRedirectUriTemplate()) + .scopes(clientRegistration.getScopes()) + .build(); + OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr + .redirectUri(clientRegistration.getRedirectUriTemplate()) + .build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); + return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange, user, user.getAuthorities(), accessToken); + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index 6f257a38f5..4fc59f87d6 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -103,7 +103,7 @@ public class AuthenticationWebFilter implements WebFilter { .onAuthenticationFailure(webFilterExchange, e)); } - private Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { + protected Mono onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) { ServerWebExchange exchange = webFilterExchange.getExchange(); SecurityContextImpl securityContext = new SecurityContextImpl(); securityContext.setAuthentication(authentication);