Add OAuth2LoginAuthenticationWebFilter

This is necessary so that the saving of the authorized client occurs
outside of the ReactiveAuthenticationManager. It will allow for
saving with the ServerWebExchange when ReactiveOAuth2AuthorizedClientRepository
is added.

Issue: gh-5621
This commit is contained in:
Rob Winch 2018-08-01 15:49:34 -05:00
parent dd7925cb63
commit f843da1942
8 changed files with 243 additions and 108 deletions

View File

@ -15,20 +15,6 @@
*/
package org.springframework.security.config.web.server;
import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.Ordered;
@ -55,6 +41,9 @@ import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2LoginAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager;
@ -118,9 +107,22 @@ import org.springframework.web.cors.reactive.DefaultCorsProcessor;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry;
/**
* A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux.
@ -445,7 +447,7 @@ public class ServerHttpSecurity {
public class OAuth2LoginSpec {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
public OAuth2LoginSpec clientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
@ -453,7 +455,12 @@ public class ServerHttpSecurity {
}
public OAuth2LoginSpec authorizedClientService(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
this.authorizedClientService = authorizedClientService;
this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(authorizedClientService);
return this;
}
public OAuth2LoginSpec authorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.authorizedClientRepository = authorizedClientRepository;
return this;
}
@ -468,22 +475,21 @@ public class ServerHttpSecurity {
protected void configure(ServerHttpSecurity http) {
ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService();
ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository);
WebClientReactiveAuthorizationCodeTokenResponseClient client = new WebClientReactiveAuthorizationCodeTokenResponseClient();
ReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService();
ReactiveAuthenticationManager manager = new OAuth2LoginReactiveAuthenticationManager(client, userService,
authorizedClientService);
ReactiveAuthenticationManager manager = new OAuth2LoginReactiveAuthenticationManager(client, userService);
boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
if (oidcAuthenticationProviderEnabled) {
OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, new OidcReactiveOAuth2UserService(), authorizedClientService);
OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, new OidcReactiveOAuth2UserService());
manager = new DelegatingReactiveAuthenticationManager(oidc, manager);
}
AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(manager);
AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository);
authenticationFilter.setRequiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"));
authenticationFilter.setServerAuthenticationConverter(new ServerOAuth2LoginAuthenticationTokenConverter(clientRegistrationRepository));
@ -532,14 +538,27 @@ public class ServerHttpSecurity {
return this.clientRegistrationRepository;
}
private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
ServerOAuth2AuthorizedClientRepository result = this.authorizedClientRepository;
if (result == null) {
result = getBeanOrNull(ServerOAuth2AuthorizedClientRepository.class);
}
if (result == null) {
ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService();
if (authorizedClientService != null) {
result = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(
authorizedClientService);
}
}
return result;
}
private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() {
if (this.authorizedClientService == null) {
this.authorizedClientService = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
if (service == null) {
service = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository());
}
if (this.authorizedClientService == null) {
this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository());
}
return this.authorizedClientService;
return service;
}
private OAuth2LoginSpec() {}

View File

@ -66,20 +66,15 @@ public class OAuth2LoginReactiveAuthenticationManager implements
private final ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
public OAuth2LoginReactiveAuthenticationManager(
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient,
ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService,
ReactiveOAuth2AuthorizedClientService authorizedClientService) {
ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
Assert.notNull(userService, "userService cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService");
this.accessTokenResponseClient = accessTokenResponseClient;
this.userService = userService;
this.authorizedClientService = authorizedClientService;
}
@Override
@ -108,13 +103,13 @@ public class OAuth2LoginReactiveAuthenticationManager implements
});
}
private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
OAuth2UserRequest userRequest = new OAuth2UserRequest(
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
return this.userService.loadUser(userRequest)
.flatMap(oauth2User -> {
.map(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
@ -125,17 +120,7 @@ public class OAuth2LoginReactiveAuthenticationManager implements
mappedAuthorities,
accessToken,
accessTokenResponse.getRefreshToken());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(),
authenticationResult.getName(),
authenticationResult.getAccessToken(),
authenticationResult.getRefreshToken());
OAuth2AuthenticationToken result = new OAuth2AuthenticationToken(
authenticationResult.getPrincipal(),
authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
return this.authorizedClientService.saveAuthorizedClient(authorizedClient, authenticationResult)
.thenReturn(result);
return authenticationResult;
});
}
}

View File

@ -19,9 +19,6 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
@ -85,22 +82,17 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
private final ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService;
private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
private Function<ClientRegistration, ReactiveJwtDecoder> decoderFactory = new DefaultDecoderFactory();
public OidcAuthorizationCodeReactiveAuthenticationManager(
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient,
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService,
ReactiveOAuth2AuthorizedClientService authorizedClientService) {
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
Assert.notNull(userService, "userService cannot be null");
Assert.notNull(authorizedClientService, "authorizedClientService");
this.accessTokenResponseClient = accessTokenResponseClient;
this.userService = userService;
this.authorizedClientService = authorizedClientService;
}
@Override
@ -157,7 +149,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
this.decoderFactory = decoderFactory;
}
private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
@ -173,26 +165,16 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
return createOidcToken(clientRegistration, accessTokenResponse)
.map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
.flatMap(this.userService::loadUser)
.flatMap(oauth2User -> {
.map(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken(
return new OAuth2LoginAuthenticationToken(
authorizationCodeAuthentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange(),
oauth2User,
mappedAuthorities,
accessToken);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(),
authenticationResult.getName(),
authenticationResult.getAccessToken());
OAuth2AuthenticationToken result = new OAuth2AuthenticationToken(
authenticationResult.getPrincipal(),
authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
return this.authorizedClientService.saveAuthorizedClient(authorizedClient, authenticationResult)
.thenReturn(result);
});
}

View File

@ -0,0 +1,70 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server.authentication;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
/**
* A specialized {@link AuthenticationWebFilter} that converts from an {@link OAuth2LoginAuthenticationToken} to an
* {@link OAuth2AuthenticationToken} and saves the {@link OAuth2AuthorizedClient}
*
* @author Rob Winch
* @since 5.1
*/
public class OAuth2LoginAuthenticationWebFilter extends AuthenticationWebFilter {
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
/**
* Creates an instance
*
* @param authenticationManager the authentication manager to use
* @param authorizedClientRepository
*/
public OAuth2LoginAuthenticationWebFilter(
ReactiveAuthenticationManager authenticationManager,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
super(authenticationManager);
Assert.notNull(authorizedClientRepository, "authorizedClientService cannot be null");
this.authorizedClientRepository = authorizedClientRepository;
}
@Override
protected Mono<Void> onAuthenticationSuccess(Authentication authentication,
WebFilterExchange webFilterExchange) {
OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) authentication;
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(),
authenticationResult.getName(),
authenticationResult.getAccessToken(),
authenticationResult.getRefreshToken());
OAuth2AuthenticationToken result = new OAuth2AuthenticationToken(
authenticationResult.getPrincipal(),
authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authenticationResult, webFilterExchange.getExchange())
.then(super.onAuthenticationSuccess(result, webFilterExchange));
}
}

View File

@ -42,8 +42,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -81,32 +79,20 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
@Before
public void setup() {
this.manager = new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService);
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
this.manager = new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService);
}
@Test
public void constructorWhenNullAccessTokenResponseClientThenIllegalArgumentException() {
this.accessTokenResponseClient = null;
assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService))
assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenNullUserServiceThenIllegalArgumentException() {
this.userService = null;
assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenNullAuthorizedClientServiceThenIllegalArgumentException() {
this.authorizedClientService = null;
assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService))
assertThatThrownBy(() -> new OAuth2LoginReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService))
.isInstanceOf(IllegalArgumentException.class);
}
@ -164,7 +150,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
when(this.userService.loadUser(any())).thenReturn(Mono.just(user));
OAuth2AuthenticationToken result = (OAuth2AuthenticationToken) this.manager.authenticate(loginToken()).block();
OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block();
assertThat(result.getPrincipal()).isEqualTo(user);
assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities());

View File

@ -24,7 +24,6 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
@ -72,9 +71,6 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Mock
private ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
@Mock
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
@Mock
private ReactiveJwtDecoder jwtDecoder;
@ -92,33 +88,20 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Before
public void setup() {
this.manager = new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService);
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(
Mono.empty());
this.manager = new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService);
}
@Test
public void constructorWhenNullAccessTokenResponseClientThenIllegalArgumentException() {
this.accessTokenResponseClient = null;
assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService))
assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenNullUserServiceThenIllegalArgumentException() {
this.userService = null;
assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenNullAuthorizedClientServiceThenIllegalArgumentException() {
this.authorizedClientService = null;
assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService,
this.authorizedClientService))
assertThatThrownBy(() -> new OidcAuthorizationCodeReactiveAuthenticationManager(this.accessTokenResponseClient, this.userService))
.isInstanceOf(IllegalArgumentException.class);
}

View File

@ -0,0 +1,110 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server.authentication;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.web.server.handler.DefaultWebFilterChain;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* @author Rob Winch
* @since 5.1
*/
@RunWith(MockitoJUnitRunner.class)
public class OAuth2LoginAuthenticationWebFilterTests {
@Mock
private ReactiveAuthenticationManager authenticationManager;
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2LoginAuthenticationWebFilter filter;
private WebFilterExchange webFilterExchange;
private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration();
private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse
.success("code")
.state("state");
@Before
public void setup() {
this.filter = new OAuth2LoginAuthenticationWebFilter(this.authenticationManager, this.authorizedClientRepository);
this.webFilterExchange = new WebFilterExchange(MockServerWebExchange.from(MockServerHttpRequest.get("/")), new DefaultWebFilterChain(exchange -> exchange.getResponse().setComplete()));
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
.thenReturn(Mono.empty());
}
@Test
public void onAuthenticationSuccessWhenOAuth2LoginAuthenticationTokenThenSavesAuthorizedClient() {
this.filter.onAuthenticationSuccess(loginToken(), this.webFilterExchange).block();
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
}
private OAuth2LoginAuthenticationToken loginToken() {
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"token",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)),
Collections.singleton("user"));
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
.singletonMap("user", "rob"), "user");
ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
.authorizationCode()
.state("state")
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(clientRegistration.getRedirectUriTemplate())
.scopes(clientRegistration.getScopes())
.build();
OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr
.redirectUri(clientRegistration.getRedirectUriTemplate())
.build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse);
return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange, user, user.getAuthorities(), accessToken);
}
}

View File

@ -103,7 +103,7 @@ public class AuthenticationWebFilter implements WebFilter {
.onAuthenticationFailure(webFilterExchange, e));
}
private Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {
protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {
ServerWebExchange exchange = webFilterExchange.getExchange();
SecurityContextImpl securityContext = new SecurityContextImpl();
securityContext.setAuthentication(authentication);