DefaultReactiveOAuth2AuthorizedClientManager defaults ServerWebExchange

Fixes gh-7390
This commit is contained in:
Joe Grandja 2019-09-08 06:30:38 -04:00
parent 96d44cd4b7
commit dcdeab596d
2 changed files with 101 additions and 61 deletions

View File

@ -70,35 +70,52 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
Authentication principal = authorizeRequest.getPrincipal();
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.switchIfEmpty(Mono.defer(() ->
this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
.flatMap(authorizedClient -> {
// Re-authorize
return authorizationContext(authorizeRequest, authorizedClient)
.flatMap(this.authorizedClientProvider::authorize)
.doOnNext(reauthorizedClient ->
this.authorizedClientRepository.saveAuthorizedClient(
reauthorizedClient, principal, serverWebExchange))
.flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange))
// Default to the existing authorizedClient if the client was not re-authorized
.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
authorizeRequest.getAuthorizedClient() : authorizedClient);
})
.switchIfEmpty(Mono.defer(() ->
.switchIfEmpty(Mono.deferWithContext(context ->
// Authorize
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
.flatMap(this.authorizedClientProvider::authorize)
.doOnNext(authorizedClient ->
this.authorizedClientRepository.saveAuthorizedClient(
authorizedClient, principal, serverWebExchange))
));
.flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange))
.subscriberContext(context)
)
);
}
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) {
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange));
}
private Mono<OAuth2AuthorizedClient> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) {
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.map(exchange -> {
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
return authorizedClient;
})
.defaultIfEmpty(authorizedClient);
}
private static Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
@ -158,15 +175,20 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
@Override
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
Map<String, Object> contextAttributes = Collections.emptyMap();
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.flatMap(exchange -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
})
.defaultIfEmpty(Collections.emptyMap());
}
}
}

View File

@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import java.util.Collections;
import java.util.HashMap;
@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
private Authentication principal;
private OAuth2AuthorizedClient authorizedClient;
private MockServerWebExchange serverWebExchange;
private Context context;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
@SuppressWarnings("unchecked")
@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class);
when(this.authorizedClientRepository.loadAuthorizedClient(
anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
when(this.authorizedClientRepository.saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
this.contextAttributesMapper = mock(Function.class);
@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
}
@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
.hasMessage("contextAttributesMapper cannot be null");
}
@Test
public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("serverWebExchange cannot be null");
}
@Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block())
@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
}
@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isNull();
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
}
@SuppressWarnings("unchecked")
@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientProvider.authorize(
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
}
@Test
public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientProvider.authorize(
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
}
@SuppressWarnings("unchecked")
@Test
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(any());
@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
// Set custom contextAttributesMapper capable of mapping the form parameters
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> {
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
return Mono.just(serverWebExchange)
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest ->
currentServerWebExchange()
.flatMap(ServerWebExchange::getFormData)
.map(formData -> {
Map<String, Object> contextAttributes = new HashMap<>();
String username = formData.getFirst(OAuth2ParameterNames.USERNAME);
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
String password = formData.getFirst(OAuth2ParameterNames.PASSWORD);
if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
}
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
return contextAttributes;
});
});
})
);
this.serverWebExchange = MockServerWebExchange.builder(
MockServerHttpRequest
@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body("username=username&password=password"))
.build();
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
this.authorizedClientManager.authorize(authorizeRequest).block();
this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
}
@SuppressWarnings("unchecked")
@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
.get("/")
.queryParam(OAuth2ParameterNames.SCOPE, "read write"))
.build();
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build();
this.authorizedClientManager.authorize(reauthorizeRequest).block();
this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write");
}
private Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
}