ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId
Issue: gh-4921
This commit is contained in:
parent
28537fa3b6
commit
158b8aa6d5
|
@ -27,12 +27,15 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.OAuth2ClientException;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
||||
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.reactive.function.BodyInserters;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
|
@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
|
||||
*/
|
||||
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
|
||||
public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
|
||||
|
||||
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
|
||||
AuthorityUtils.createAuthorityList("ROLE_USER"));
|
||||
|
||||
private Clock clock = Clock.systemUTC();
|
||||
|
||||
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
||||
|
||||
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
|
||||
new WebClientReactiveClientCredentialsTokenResponseClient();
|
||||
|
||||
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||
|
||||
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||
|
||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
|
||||
|
||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||
this.clientRegistrationRepository = clientRegistrationRepository;
|
||||
this.authorizedClientRepository = authorizedClientRepository;
|
||||
}
|
||||
|
||||
|
@ -164,6 +174,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
|
||||
* client_credentials grant.
|
||||
* @param clientCredentialsTokenResponseClient the client to use
|
||||
*/
|
||||
public void setClientCredentialsTokenResponseClient(
|
||||
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
|
||||
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
|
||||
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* An access token will be considered expired by comparing its expiration to now +
|
||||
* this skewed Duration. The default is 1 minute.
|
||||
|
@ -208,7 +229,39 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
|
||||
ServerWebExchange exchange, Authentication principal) {
|
||||
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
|
||||
.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
|
||||
.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
|
||||
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
|
||||
.flatMap(clientRegistration -> {
|
||||
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
|
||||
return clientCredentials(clientRegistration, exchange);
|
||||
}
|
||||
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
|
||||
});
|
||||
}
|
||||
|
||||
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
|
||||
ClientRegistration clientRegistration, ServerWebExchange exchange) {
|
||||
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
|
||||
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
|
||||
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
|
||||
return currentAuthentication()
|
||||
.flatMap(principal -> {
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
||||
clientRegistration, (principal != null ?
|
||||
principal.getName() :
|
||||
"anonymousUser"),
|
||||
tokenResponse.getAccessToken());
|
||||
|
||||
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
|
||||
.thenReturn(authorizedClient);
|
||||
});
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
|
|||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
|
@ -71,7 +72,10 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
|
|||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||
@Mock
|
||||
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
|
||||
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||
|
||||
@Mock
|
||||
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||
|
||||
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
|
||||
|
||||
|
@ -125,7 +129,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
|
||||
@Test
|
||||
public void filterWhenRefreshRequiredThenRefresh() {
|
||||
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
|
||||
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
|
||||
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
|
@ -140,7 +144,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
this.accessToken.getTokenValue(),
|
||||
issuedAt,
|
||||
accessTokenExpiresAt);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
|
@ -154,7 +158,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||
.block();
|
||||
|
||||
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
|
||||
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(2);
|
||||
|
@ -174,7 +178,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
|
||||
@Test
|
||||
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
|
||||
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
|
||||
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
|
||||
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
|
||||
.tokenType(OAuth2AccessToken.TokenType.BEARER)
|
||||
.expiresIn(3600)
|
||||
|
@ -189,7 +193,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
this.accessToken.getTokenValue(),
|
||||
issuedAt,
|
||||
accessTokenExpiresAt);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
|
@ -201,7 +205,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
this.function.filter(request, this.exchange)
|
||||
.block();
|
||||
|
||||
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
|
||||
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(2);
|
||||
|
@ -221,7 +225,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
|
||||
@Test
|
||||
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
"principalName", this.accessToken);
|
||||
|
@ -243,7 +247,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
|
||||
@Test
|
||||
public void filterWhenNotExpiredThenShouldRefreshFalse() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
|
@ -266,12 +270,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
|
||||
@Test
|
||||
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
"principalName", this.accessToken, refreshToken);
|
||||
when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
|
||||
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
|
||||
.build();
|
||||
|
|
|
@ -18,7 +18,9 @@ package sample.config;
|
|||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
|
||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
|
||||
/**
|
||||
|
@ -29,9 +31,10 @@ import org.springframework.web.reactive.function.client.WebClient;
|
|||
public class WebClientConfig {
|
||||
|
||||
@Bean
|
||||
WebClient webClient() {
|
||||
WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository,
|
||||
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||
return WebClient.builder()
|
||||
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction())
|
||||
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue