From ffb5a3a0d426465cd2e34d90d72fe50e22913e21 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 4 Feb 2020 15:56:21 -0700 Subject: [PATCH] Add oauth2Client WebTestClient Support Fixes gh-7910 --- .../server/SecurityMockServerConfigurers.java | 185 +++++++++++++++--- ...ockServerConfigurersOAuth2ClientTests.java | 166 ++++++++++++++++ 2 files changed, 325 insertions(+), 26 deletions(-) create mode 100644 test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java index fa49533a4e..cd20a2db11 100644 --- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java @@ -60,6 +60,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.test.web.reactive.server.MockServerConfigurer; @@ -182,6 +183,39 @@ public class SecurityMockServerConfigurers { return new OidcLoginMutator(accessToken); } + /** + * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the session. + * All details are declarative and do not require the corresponding OAuth 2.0 tokens to be valid. + * + *

+ * The support works by associating the authorized client to the ServerWebExchange + * via the {@link WebSessionServerOAuth2AuthorizedClientRepository} + *

+ * + * @return the {@link OAuth2ClientMutator} to further configure or use + * @since 5.3 + */ + public static OAuth2ClientMutator mockOAuth2Client() { + return new OAuth2ClientMutator(); + } + + /** + * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the session. + * All details are declarative and do not require the corresponding OAuth 2.0 tokens to be valid. + * + *

+ * The support works by associating the authorized client to the ServerWebExchange + * via the {@link WebSessionServerOAuth2AuthorizedClientRepository} + *

+ * + * @param registrationId The registration id associated with the {@link OAuth2AuthorizedClient} + * @return the {@link OAuth2ClientMutator} to further configure or use + * @since 5.3 + */ + public static OAuth2ClientMutator mockOAuth2Client(String registrationId) { + return new OAuth2ClientMutator(registrationId); + } + public static CsrfMutator csrf() { return new CsrfMutator(); } @@ -591,12 +625,19 @@ public class SecurityMockServerConfigurers { @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { OAuth2AuthenticationToken token = getToken(); - builder.filters(addAuthorizedClientFilter(token)); + mockOAuth2Client() + .accessToken(this.accessToken) + .clientRegistration(this.clientRegistration) + .beforeServerCreated(builder); mockAuthentication(getToken()).beforeServerCreated(builder); } @Override public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { + mockOAuth2Client() + .accessToken(this.accessToken) + .clientRegistration(this.clientRegistration) + .afterConfigureAdded(serverSpec); mockAuthentication(getToken()).afterConfigureAdded(serverSpec); } @@ -606,26 +647,18 @@ public class SecurityMockServerConfigurers { @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { OAuth2AuthenticationToken token = getToken(); - httpHandlerBuilder.filters(addAuthorizedClientFilter(token)); + mockOAuth2Client() + .accessToken(this.accessToken) + .clientRegistration(this.clientRegistration) + .afterConfigurerAdded(builder, httpHandlerBuilder, connector); mockAuthentication(token).afterConfigurerAdded(builder, httpHandlerBuilder, connector); } - private Consumer> addAuthorizedClientFilter(OAuth2AuthenticationToken token) { - OAuth2AuthorizedClient client = getClient(); - return filters -> filters.add(0, (exchange, chain) -> - this.authorizedClientRepository.saveAuthorizedClient(client, token, exchange) - .then(chain.filter(exchange))); - } - private OAuth2AuthenticationToken getToken() { OAuth2User oauth2User = this.oauth2User.get(); return new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId()); } - private OAuth2AuthorizedClient getClient() { - return new OAuth2AuthorizedClient(this.clientRegistration, getToken().getName(), this.accessToken); - } - private ClientRegistration.Builder clientRegistrationBuilder() { return ClientRegistration.withRegistrationId("test") .authorizationGrantType(AuthorizationGrantType.PASSWORD) @@ -760,12 +793,19 @@ public class SecurityMockServerConfigurers { @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { OAuth2AuthenticationToken token = getToken(); - builder.filters(addAuthorizedClientFilter(token)); + mockOAuth2Client() + .accessToken(this.accessToken) + .clientRegistration(this.clientRegistration) + .beforeServerCreated(builder); mockAuthentication(getToken()).beforeServerCreated(builder); } @Override public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { + mockOAuth2Client() + .accessToken(this.accessToken) + .clientRegistration(this.clientRegistration) + .afterConfigureAdded(serverSpec); mockAuthentication(getToken()).afterConfigureAdded(serverSpec); } @@ -775,17 +815,13 @@ public class SecurityMockServerConfigurers { @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { OAuth2AuthenticationToken token = getToken(); - httpHandlerBuilder.filters(addAuthorizedClientFilter(token)); + mockOAuth2Client() + .accessToken(this.accessToken) + .clientRegistration(this.clientRegistration) + .afterConfigurerAdded(builder, httpHandlerBuilder, connector); mockAuthentication(token).afterConfigurerAdded(builder, httpHandlerBuilder, connector); } - private Consumer> addAuthorizedClientFilter(OAuth2AuthenticationToken token) { - OAuth2AuthorizedClient client = getClient(); - return filters -> filters.add(0, (exchange, chain) -> - authorizedClientRepository.saveAuthorizedClient(client, token, exchange) - .then(chain.filter(exchange))); - } - private ClientRegistration.Builder clientRegistrationBuilder() { return ClientRegistration.withRegistrationId("test") .authorizationGrantType(AuthorizationGrantType.PASSWORD) @@ -798,10 +834,6 @@ public class SecurityMockServerConfigurers { return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId()); } - private OAuth2AuthorizedClient getClient() { - return new OAuth2AuthorizedClient(this.clientRegistration, getToken().getName(), this.accessToken); - } - private Collection getAuthorities() { if (this.authorities == null) { Set authorities = new LinkedHashSet<>(); @@ -831,4 +863,105 @@ public class SecurityMockServerConfigurers { return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo); } } + + /** + * @author Josh Cummings + * @since 5.3 + */ + public final static class OAuth2ClientMutator implements WebTestClientConfigurer, MockServerConfigurer { + private String registrationId = "test"; + private ClientRegistration clientRegistration; + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token", null, null, Collections.singleton("user")); + + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository = + new WebSessionServerOAuth2AuthorizedClientRepository(); + + private OAuth2ClientMutator() { + } + + private OAuth2ClientMutator(String registrationId) { + this.registrationId = registrationId; + clientRegistration(c -> {}); + } + + /** + * Use this {@link ClientRegistration} + * + * @param clientRegistration + * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration + */ + public OAuth2ClientMutator clientRegistration(ClientRegistration clientRegistration) { + this.clientRegistration = clientRegistration; + return this; + } + + /** + * Use this {@link Consumer} to configure a {@link ClientRegistration} + * + * @param clientRegistrationConfigurer the {@link ClientRegistration} configurer + * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration + */ + public OAuth2ClientMutator clientRegistration + (Consumer clientRegistrationConfigurer) { + + ClientRegistration.Builder builder = clientRegistrationBuilder(); + clientRegistrationConfigurer.accept(builder); + this.clientRegistration = builder.build(); + return this; + } + + /** + * Use this {@link OAuth2AccessToken} + * + * @param accessToken the {@link OAuth2AccessToken} to use + * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration + */ + public OAuth2ClientMutator accessToken(OAuth2AccessToken accessToken) { + this.accessToken = accessToken; + return this; + } + + + @Override + public void beforeServerCreated(WebHttpHandlerBuilder builder) { + builder.filters(addAuthorizedClientFilter()); + } + + @Override + public void afterConfigureAdded(WebTestClient.MockServerSpec serverSpec) { + + } + + @Override + public void afterConfigurerAdded( + WebTestClient.Builder builder, + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, + @Nullable ClientHttpConnector connector) { + httpHandlerBuilder.filters(addAuthorizedClientFilter()); + } + + private Consumer> addAuthorizedClientFilter() { + OAuth2AuthorizedClient client = getClient(); + return filters -> filters.add(0, (exchange, chain) -> + authorizedClientRepository.saveAuthorizedClient(client, null, exchange) + .then(chain.filter(exchange))); + } + + private OAuth2AuthorizedClient getClient() { + if (this.clientRegistration == null) { + throw new IllegalArgumentException("Please specify a ClientRegistration via one " + + "of the clientRegistration methods"); + } + return new OAuth2AuthorizedClient(this.clientRegistration, "test-subject", this.accessToken); + } + + private ClientRegistration.Builder clientRegistrationBuilder() { + return ClientRegistration.withRegistrationId(this.registrationId) + .authorizationGrantType(AuthorizationGrantType.PASSWORD) + .clientId("test-client") + .clientSecret("test-secret") + .tokenUri("https://idp.example.org/oauth/token"); + } + } } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java new file mode 100644 index 0000000000..f15ee4eb95 --- /dev/null +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.test.web.reactive.server; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.DispatcherHandler; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Client; +import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; + +@RunWith(MockitoJUnitRunner.class) +public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMockServerConfigurersTests { + private OAuth2LoginController controller = new OAuth2LoginController(); + + @Mock + private ReactiveClientRegistrationRepository clientRegistrationRepository; + + private WebTestClient client; + + @Before + public void setup() { + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = + new WebSessionServerOAuth2AuthorizedClientRepository(); + + this.client = WebTestClient + .bindToController(this.controller) + .argumentResolvers(c -> c.addCustomResolver( + new OAuth2AuthorizedClientArgumentResolver + (this.clientRegistrationRepository, authorizedClientRepository))) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) + .apply(springSecurity()) + .configureClient() + .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .build(); + } + + @Test + public void oauth2ClientWhenUsingDefaultsThenException() + throws Exception { + + WebHttpHandlerBuilder builder = WebHttpHandlerBuilder.webHandler(new DispatcherHandler()); + assertThatCode(() -> mockOAuth2Client().beforeServerCreated(builder)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ClientRegistration"); + } + + @Test + public void oauth2ClientWhenUsingRegistrationIdThenProducesAuthorizedClient() + throws Exception { + + this.client.mutateWith(mockOAuth2Client("registration-id")) + .get().uri("/client") + .exchange() + .expectStatus().isOk(); + + OAuth2AuthorizedClient client = this.controller.authorizedClient; + assertThat(client).isNotNull(); + assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); + assertThat(client.getAccessToken().getTokenValue()).isEqualTo("access-token"); + assertThat(client.getRefreshToken()).isNull(); + } + + @Test + public void oauth2ClientWhenClientRegistrationThenUses() + throws Exception { + + ClientRegistration clientRegistration = clientRegistration() + .registrationId("registration-id").clientId("client-id").build(); + this.client.mutateWith(mockOAuth2Client().clientRegistration(clientRegistration)) + .get().uri("/client") + .exchange() + .expectStatus().isOk(); + + OAuth2AuthorizedClient client = this.controller.authorizedClient; + assertThat(client).isNotNull(); + assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); + assertThat(client.getAccessToken().getTokenValue()).isEqualTo("access-token"); + assertThat(client.getRefreshToken()).isNull(); + } + + @Test + public void oauth2ClientWhenClientRegistrationConsumerThenUses() + throws Exception { + + this.client.mutateWith(mockOAuth2Client("registration-id") + .clientRegistration(c -> c.clientId("client-id"))) + .get().uri("/client") + .exchange() + .expectStatus().isOk(); + + OAuth2AuthorizedClient client = this.controller.authorizedClient; + assertThat(client).isNotNull(); + assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); + assertThat(client.getClientRegistration().getClientId()).isEqualTo("client-id"); + assertThat(client.getAccessToken().getTokenValue()).isEqualTo("access-token"); + assertThat(client.getRefreshToken()).isNull(); + } + + @Test + public void oauth2ClientWhenAccessTokenThenUses() + throws Exception { + + OAuth2AccessToken accessToken = noScopes(); + this.client.mutateWith(mockOAuth2Client("registration-id") + .accessToken(accessToken)) + .get().uri("/client") + .exchange() + .expectStatus().isOk(); + + OAuth2AuthorizedClient client = this.controller.authorizedClient; + assertThat(client).isNotNull(); + assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id"); + assertThat(client.getAccessToken().getTokenValue()).isEqualTo("no-scopes"); + assertThat(client.getRefreshToken()).isNull(); + } + + @RestController + static class OAuth2LoginController { + volatile OAuth2AuthorizedClient authorizedClient; + + @GetMapping("/client") + String authorizedClient + (@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) { + this.authorizedClient = authorizedClient; + return authorizedClient.getPrincipalName(); + } + } +}