DefaultReactiveOAuth2AuthorizedClientManager defaults ServerWebExchange
Fixes gh-7390
This commit is contained in:
parent
96d44cd4b7
commit
dcdeab596d
|
@ -70,35 +70,52 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
|
|||
|
||||
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
|
||||
Authentication principal = authorizeRequest.getPrincipal();
|
||||
|
||||
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
|
||||
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
|
||||
|
||||
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
|
||||
.switchIfEmpty(Mono.defer(() ->
|
||||
this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
|
||||
.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
|
||||
.flatMap(authorizedClient -> {
|
||||
// Re-authorize
|
||||
return authorizationContext(authorizeRequest, authorizedClient)
|
||||
.flatMap(this.authorizedClientProvider::authorize)
|
||||
.doOnNext(reauthorizedClient ->
|
||||
this.authorizedClientRepository.saveAuthorizedClient(
|
||||
reauthorizedClient, principal, serverWebExchange))
|
||||
.flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange))
|
||||
// Default to the existing authorizedClient if the client was not re-authorized
|
||||
.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
|
||||
authorizeRequest.getAuthorizedClient() : authorizedClient);
|
||||
})
|
||||
.switchIfEmpty(Mono.defer(() ->
|
||||
.switchIfEmpty(Mono.deferWithContext(context ->
|
||||
// Authorize
|
||||
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
|
||||
"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
|
||||
.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
|
||||
.flatMap(this.authorizedClientProvider::authorize)
|
||||
.doOnNext(authorizedClient ->
|
||||
this.authorizedClientRepository.saveAuthorizedClient(
|
||||
authorizedClient, principal, serverWebExchange))
|
||||
));
|
||||
.flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange))
|
||||
.subscriberContext(context)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) {
|
||||
return Mono.justOrEmpty(serverWebExchange)
|
||||
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
|
||||
.flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange));
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizedClient> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) {
|
||||
return Mono.justOrEmpty(serverWebExchange)
|
||||
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
|
||||
.map(exchange -> {
|
||||
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
|
||||
return authorizedClient;
|
||||
})
|
||||
.defaultIfEmpty(authorizedClient);
|
||||
}
|
||||
|
||||
private static Mono<ServerWebExchange> currentServerWebExchange() {
|
||||
return Mono.subscriberContext()
|
||||
.filter(c -> c.hasKey(ServerWebExchange.class))
|
||||
.map(c -> c.get(ServerWebExchange.class));
|
||||
}
|
||||
|
||||
private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
|
||||
|
@ -158,15 +175,20 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
|
|||
|
||||
@Override
|
||||
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
|
||||
Map<String, Object> contextAttributes = Collections.emptyMap();
|
||||
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
|
||||
String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
|
||||
if (StringUtils.hasText(scope)) {
|
||||
contextAttributes = new HashMap<>();
|
||||
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
|
||||
StringUtils.delimitedListToStringArray(scope, " "));
|
||||
}
|
||||
return Mono.just(contextAttributes);
|
||||
return Mono.justOrEmpty(serverWebExchange)
|
||||
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
|
||||
.flatMap(exchange -> {
|
||||
Map<String, Object> contextAttributes = Collections.emptyMap();
|
||||
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
|
||||
if (StringUtils.hasText(scope)) {
|
||||
contextAttributes = new HashMap<>();
|
||||
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
|
||||
StringUtils.delimitedListToStringArray(scope, " "));
|
||||
}
|
||||
return Mono.just(contextAttributes);
|
||||
})
|
||||
.defaultIfEmpty(Collections.emptyMap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
|
|||
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
|
||||
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
|
@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
private Authentication principal;
|
||||
private OAuth2AuthorizedClient authorizedClient;
|
||||
private MockServerWebExchange serverWebExchange;
|
||||
private Context context;
|
||||
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class);
|
||||
when(this.authorizedClientRepository.loadAuthorizedClient(
|
||||
anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
|
||||
when(this.authorizedClientRepository.saveAuthorizedClient(
|
||||
any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
|
||||
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
|
||||
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
|
||||
this.contextAttributesMapper = mock(Function.class);
|
||||
|
@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
|
||||
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
|
||||
this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
|
||||
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
|
||||
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
|
||||
}
|
||||
|
||||
|
@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
.hasMessage("contextAttributesMapper cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() {
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
|
||||
.principal(this.principal)
|
||||
.build();
|
||||
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("serverWebExchange cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block())
|
||||
|
@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
|
||||
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
|
||||
}
|
||||
|
@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
|
||||
.subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
|
||||
|
@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
|
||||
|
||||
assertThat(authorizedClient).isNull();
|
||||
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
|
||||
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
|
||||
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(
|
||||
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
|
||||
|
||||
when(this.authorizedClientProvider.authorize(
|
||||
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
|
||||
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
|
||||
.subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
|
||||
|
@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved() {
|
||||
when(this.clientRegistrationRepository.findByRegistrationId(
|
||||
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
|
||||
|
||||
when(this.authorizedClientProvider.authorize(
|
||||
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
|
||||
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
|
||||
.principal(this.principal)
|
||||
.build();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
|
||||
|
||||
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
|
||||
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
|
||||
assertThat(authorizationContext.getAuthorizedClient()).isNull();
|
||||
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
|
||||
|
||||
assertThat(authorizedClient).isSameAs(this.authorizedClient);
|
||||
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
|
||||
|
@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
|
||||
.subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
verify(this.contextAttributesMapper).apply(any());
|
||||
|
@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
|
||||
|
||||
// Set custom contextAttributesMapper capable of mapping the form parameters
|
||||
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> {
|
||||
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
|
||||
return Mono.just(serverWebExchange)
|
||||
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest ->
|
||||
currentServerWebExchange()
|
||||
.flatMap(ServerWebExchange::getFormData)
|
||||
.map(formData -> {
|
||||
Map<String, Object> contextAttributes = new HashMap<>();
|
||||
String username = formData.getFirst(OAuth2ParameterNames.USERNAME);
|
||||
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
|
||||
String password = formData.getFirst(OAuth2ParameterNames.PASSWORD);
|
||||
if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
|
||||
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
|
||||
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
|
||||
}
|
||||
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
|
||||
return contextAttributes;
|
||||
});
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
this.serverWebExchange = MockServerWebExchange.builder(
|
||||
MockServerHttpRequest
|
||||
|
@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
|
||||
.body("username=username&password=password"))
|
||||
.build();
|
||||
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
|
||||
|
||||
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
this.authorizedClientManager.authorize(authorizeRequest).block();
|
||||
this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
|
||||
|
@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
|
||||
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
|
||||
.subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
|
||||
|
@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
|
||||
|
||||
assertThat(authorizedClient).isSameAs(this.authorizedClient);
|
||||
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
|
||||
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
|
||||
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
|
||||
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
|
||||
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
|
||||
.subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
|
||||
|
@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
.get("/")
|
||||
.queryParam(OAuth2ParameterNames.SCOPE, "read write"))
|
||||
.build();
|
||||
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
|
||||
|
||||
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
|
||||
.principal(this.principal)
|
||||
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
|
||||
.build();
|
||||
this.authorizedClientManager.authorize(reauthorizeRequest).block();
|
||||
this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block();
|
||||
|
||||
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
|
||||
|
||||
|
@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
|
|||
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
|
||||
assertThat(requestScopeAttribute).contains("read", "write");
|
||||
}
|
||||
|
||||
private Mono<ServerWebExchange> currentServerWebExchange() {
|
||||
return Mono.subscriberContext()
|
||||
.filter(c -> c.hasKey(ServerWebExchange.class))
|
||||
.map(c -> c.get(ServerWebExchange.class));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue