ServerOAuth2AuthorizedClientExchangeFilterFunction defaultOAuth2AuthorizedClient
Defaults to use the OAuth2AuthenticationToken to resolve the authorized client Issue: gh-4921
This commit is contained in:
parent
158b8aa6d5
commit
ac78258847
|
@ -27,6 +27,7 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
||||
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
||||
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
|
||||
|
@ -86,6 +87,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
|
||||
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
||||
|
||||
private boolean defaultOAuth2AuthorizedClient;
|
||||
|
||||
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
|
||||
new WebClientReactiveClientCredentialsTokenResponseClient();
|
||||
|
||||
|
@ -174,6 +177,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
|
||||
}
|
||||
|
||||
/**
|
||||
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
|
||||
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
|
||||
* resolved from the current Authentication.
|
||||
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
|
||||
* Default is false.
|
||||
*/
|
||||
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
|
||||
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
|
||||
* client_credentials grant.
|
||||
|
@ -216,14 +230,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|||
if (this.authorizedClientRepository == null) {
|
||||
return Mono.empty();
|
||||
}
|
||||
String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
|
||||
if (clientRegistrationId == null) {
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||
return currentAuthentication()
|
||||
.flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal)
|
||||
);
|
||||
.flatMap(principal -> clientRegistrationId(request, principal)
|
||||
.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
|
||||
);
|
||||
}
|
||||
|
||||
private Mono<String> clientRegistrationId(ClientRequest request, Authentication authentication) {
|
||||
return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME))
|
||||
.cast(String.class)
|
||||
.switchIfEmpty(clientRegistrationId(authentication));
|
||||
}
|
||||
|
||||
private Mono<String> clientRegistrationId(Authentication authentication) {
|
||||
return Mono.justOrEmpty(authentication)
|
||||
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
|
||||
.cast(OAuth2AuthenticationToken.class)
|
||||
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
|
||||
|
|
|
@ -34,8 +34,10 @@ import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
|
|||
import org.springframework.http.server.reactive.ServerHttpRequest;
|
||||
import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.authority.AuthorityUtils;
|
||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
||||
|
@ -43,6 +45,8 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
|
|||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
|
||||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||
import org.springframework.web.reactive.function.BodyInserter;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
@ -51,6 +55,7 @@ import java.net.URI;
|
|||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -60,6 +65,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyZeroInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.http.HttpMethod.GET;
|
||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
||||
|
@ -293,6 +299,59 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|||
assertThat(getBody(request0)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||
|
||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||
"principalName", this.accessToken, refreshToken);
|
||||
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.build();
|
||||
|
||||
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
|
||||
.singletonMap("user", "rob"), "user");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id");
|
||||
this.function
|
||||
.filter(request, this.exchange)
|
||||
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||
.block();
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(1);
|
||||
|
||||
ClientRequest request0 = requests.get(0);
|
||||
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
|
||||
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
|
||||
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(getBody(request0)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
|
||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
|
||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||
.build();
|
||||
|
||||
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
|
||||
.singletonMap("user", "rob"), "user");
|
||||
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id");
|
||||
|
||||
this.function
|
||||
.filter(request, this.exchange)
|
||||
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
|
||||
.block();
|
||||
|
||||
List<ClientRequest> requests = this.exchange.getRequests();
|
||||
assertThat(requests).hasSize(1);
|
||||
|
||||
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||
}
|
||||
|
||||
private static String getBody(ClientRequest request) {
|
||||
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
||||
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
||||
|
|
Loading…
Reference in New Issue