From dcdeab596d420d45f7d43133dfaec070d57d2cd9 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Sun, 8 Sep 2019 06:30:38 -0400 Subject: [PATCH] DefaultReactiveOAuth2AuthorizedClientManager defaults ServerWebExchange Fixes gh-7390 --- ...ReactiveOAuth2AuthorizedClientManager.java | 62 +++++++---- ...iveOAuth2AuthorizedClientManagerTests.java | 100 +++++++++++------- 2 files changed, 101 insertions(+), 61 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index a04b8ee04c..399839c138 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -70,35 +70,52 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); - ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); - Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) - .switchIfEmpty(Mono.defer(() -> - this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) + .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) .flatMap(authorizedClient -> { // Re-authorize return authorizationContext(authorizeRequest, authorizedClient) .flatMap(this.authorizedClientProvider::authorize) - .doOnNext(reauthorizedClient -> - this.authorizedClientRepository.saveAuthorizedClient( - reauthorizedClient, principal, serverWebExchange)) + .flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange)) // Default to the existing authorizedClient if the client was not re-authorized .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? authorizeRequest.getAuthorizedClient() : authorizedClient); }) - .switchIfEmpty(Mono.defer(() -> + .switchIfEmpty(Mono.deferWithContext(context -> // Authorize this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) .flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration)) .flatMap(this.authorizedClientProvider::authorize) - .doOnNext(authorizedClient -> - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, principal, serverWebExchange)) - )); + .flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange)) + .subscriberContext(context) + ) + ); + } + + private Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) { + return Mono.justOrEmpty(serverWebExchange) + .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) + .flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)); + } + + private Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) { + return Mono.justOrEmpty(serverWebExchange) + .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) + .map(exchange -> { + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange); + return authorizedClient; + }) + .defaultIfEmpty(authorizedClient); + } + + private static Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); } private Mono authorizationContext(OAuth2AuthorizeRequest authorizeRequest, @@ -158,15 +175,20 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React @Override public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) { - Map contextAttributes = Collections.emptyMap(); ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); - String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); - if (StringUtils.hasText(scope)) { - contextAttributes = new HashMap<>(); - contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, - StringUtils.delimitedListToStringArray(scope, " ")); - } - return Mono.just(contextAttributes); + return Mono.justOrEmpty(serverWebExchange) + .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) + .flatMap(exchange -> { + Map contextAttributes = Collections.emptyMap(); + String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope)) { + contextAttributes = new HashMap<>(); + contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, + StringUtils.delimitedListToStringArray(scope, " ")); + } + return Mono.just(contextAttributes); + }) + .defaultIfEmpty(Collections.emptyMap()); } } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java index ade1fb4e4a..50e60ecf28 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java @@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; +import reactor.util.context.Context; import java.util.Collections; import java.util.HashMap; @@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { private Authentication principal; private OAuth2AuthorizedClient authorizedClient; private MockServerWebExchange serverWebExchange; + private Context context; private ArgumentCaptor authorizationContextCaptor; @SuppressWarnings("unchecked") @@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class); when(this.authorizedClientRepository.loadAuthorizedClient( anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty()); + when(this.authorizedClientRepository.saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty()); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); this.contextAttributesMapper = mock(Function.class); @@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); + this.context = Context.of(ServerWebExchange.class, this.serverWebExchange); this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); } @@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { .hasMessage("contextAttributesMapper cannot be null"); } - @Test - public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("serverWebExchange cannot be null"); - } - @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block()) @@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); } @@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); @@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isNull(); - verify(this.authorizedClientRepository, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } @SuppressWarnings("unchecked") @@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { when(this.clientRegistrationRepository.findByRegistrationId( eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - when(this.authorizedClientProvider.authorize( any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); @@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange)); } + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + } + @SuppressWarnings("unchecked") @Test public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { @@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(any()); @@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); // Set custom contextAttributesMapper capable of mapping the form parameters - this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> { - ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); - return Mono.just(serverWebExchange) + this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> + currentServerWebExchange() .flatMap(ServerWebExchange::getFormData) .map(formData -> { Map contextAttributes = new HashMap<>(); String username = formData.getFirst(OAuth2ParameterNames.USERNAME); + contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); String password = formData.getFirst(OAuth2ParameterNames.PASSWORD); - if (StringUtils.hasText(username) && StringUtils.hasText(password)) { - contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username); - contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); - } + contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password); return contextAttributes; - }); - }); + }) + ); this.serverWebExchange = MockServerWebExchange.builder( MockServerHttpRequest @@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { .contentType(MediaType.APPLICATION_FORM_URLENCODED) .body("username=username&password=password")) .build(); + this.context = Context.of(ServerWebExchange.class, this.serverWebExchange); OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - this.authorizedClientManager.authorize(authorizeRequest).block(); + this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest) + .subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); @@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(this.authorizedClient); - verify(this.authorizedClientRepository, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange)); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); } @SuppressWarnings("unchecked") @@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest) + .subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); @@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { .get("/") .queryParam(OAuth2ParameterNames.SCOPE, "read write")) .build(); + this.context = Context.of(ServerWebExchange.class, this.serverWebExchange); OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) .build(); - this.authorizedClientManager.authorize(reauthorizeRequest).block(); + this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); } + + private Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } }