Inject TestOAuth2AuthorizedClientRepository

Fixes gh-8603
This commit is contained in:
Josh Cummings 2020-05-27 14:33:02 -06:00
parent d014d29199
commit 900f551890
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
9 changed files with 376 additions and 30 deletions

View File

@ -11,6 +11,7 @@ dependencies {
optional project(':spring-security-oauth2-jose')
optional project(':spring-security-oauth2-resource-server')
optional 'io.projectreactor:reactor-core'
optional 'org.springframework:spring-webmvc'
optional 'org.springframework:spring-webflux'
provided 'javax.servlet:javax.servlet-api'

View File

@ -32,6 +32,7 @@ import java.util.stream.Collectors;
import com.nimbusds.oauth2.sdk.util.StringUtils;
import reactor.core.publisher.Mono;
import org.springframework.context.ApplicationContext;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.lang.Nullable;
@ -44,9 +45,13 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -70,15 +75,21 @@ import org.springframework.security.oauth2.server.resource.introspection.OAuth2I
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.MockServerConfigurer;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver;
import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static java.lang.Boolean.TRUE;
import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
/**
@ -1121,9 +1132,18 @@ public class SecurityMockServerConfigurers {
private Consumer<List<WebFilter>> addAuthorizedClientFilter() {
OAuth2AuthorizedClient client = getClient();
return filters -> filters.add(0, (exchange, chain) ->
authorizedClientRepository.saveAuthorizedClient(client, null, exchange)
.then(chain.filter(exchange)));
return filters -> filters.add(0, (exchange, chain) -> {
ReactiveOAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServerTestUtils
.getOAuth2AuthorizedClientManager(exchange);
if (!(authorizationClientManager instanceof TestReactiveOAuth2AuthorizedClientManager)) {
authorizationClientManager =
new TestReactiveOAuth2AuthorizedClientManager(authorizationClientManager);
OAuth2ClientServerTestUtils.setOAuth2AuthorizedClientManager(exchange, authorizationClientManager);
}
TestReactiveOAuth2AuthorizedClientManager.enable(exchange);
exchange.getAttributes().put(TestReactiveOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client);
return chain.filter(exchange);
});
}
private OAuth2AuthorizedClient getClient() {
@ -1141,5 +1161,136 @@ public class SecurityMockServerConfigurers {
.clientSecret("test-secret")
.tokenUri("https://idp.example.org/oauth/token");
}
/**
* Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the
* request is wrapped
*/
private static class TestReactiveOAuth2AuthorizedClientManager
implements ReactiveOAuth2AuthorizedClientManager {
final static String TOKEN_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class
.getName().concat(".TOKEN");
final static String ENABLED_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class
.getName().concat(".ENABLED");
private final ReactiveOAuth2AuthorizedClientManager delegate;
private TestReactiveOAuth2AuthorizedClientManager(ReactiveOAuth2AuthorizedClientManager delegate) {
this.delegate = delegate;
}
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
ServerWebExchange exchange =
authorizeRequest.getAttribute(ServerWebExchange.class.getName());
if (isEnabled(exchange)) {
OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME);
return Mono.just(client);
} else {
return this.delegate.authorize(authorizeRequest);
}
}
public static void enable(ServerWebExchange exchange) {
exchange.getAttributes().put(ENABLED_ATTR_NAME, TRUE);
}
public boolean isEnabled(ServerWebExchange exchange) {
return TRUE.equals(exchange.getAttribute(ENABLED_ATTR_NAME));
}
}
private static class OAuth2ClientServerTestUtils {
private static final ServerOAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO =
new WebSessionServerOAuth2AuthorizedClientRepository();
/**
* Gets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}.
* If one is not found, one based off of {@link WebSessionServerOAuth2AuthorizedClientRepository} is used.
*
* @param exchange the {@link ServerWebExchange} to obtain the
* {@link ReactiveOAuth2AuthorizedClientManager}
* @return the {@link ReactiveOAuth2AuthorizedClientManager} for the specified
* {@link ServerWebExchange}
*/
public static ReactiveOAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(ServerWebExchange exchange) {
OAuth2AuthorizedClientArgumentResolver resolver =
findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class);
if (resolver == null) {
return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient
(authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), exchange);
}
return (ReactiveOAuth2AuthorizedClientManager)
ReflectionTestUtils.getField(resolver, "authorizedClientManager");
}
/**
* Sets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}.
*
* @param exchange the {@link ServerWebExchange} to obtain the
* {@link ReactiveOAuth2AuthorizedClientManager}
* @param manager the {@link ReactiveOAuth2AuthorizedClientManager} to set
*/
public static void setOAuth2AuthorizedClientManager(ServerWebExchange exchange,
ReactiveOAuth2AuthorizedClientManager manager) {
OAuth2AuthorizedClientArgumentResolver resolver =
findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class);
if (resolver == null) {
return;
}
ReflectionTestUtils.setField(resolver, "authorizedClientManager", manager);
}
@SuppressWarnings("unchecked")
static <T extends HandlerMethodArgumentResolver> T findResolver(ServerWebExchange exchange,
Class<T> resolverClass) {
if (!ClassUtils.isPresent
("org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter", null)) {
return null;
}
return WebFluxClasspathGuard.findResolver(exchange, resolverClass);
}
private static class WebFluxClasspathGuard {
static <T extends HandlerMethodArgumentResolver> T findResolver(ServerWebExchange exchange,
Class<T> resolverClass) {
RequestMappingHandlerAdapter handlerAdapter = getRequestMappingHandlerAdapter(exchange);
if (handlerAdapter == null) {
return null;
}
ArgumentResolverConfigurer configurer = handlerAdapter.getArgumentResolverConfigurer();
if (configurer == null) {
return null;
}
List<HandlerMethodArgumentResolver> resolvers = (List<HandlerMethodArgumentResolver>)
ReflectionTestUtils.invokeGetterMethod(configurer, "customResolvers");
if (resolvers == null) {
return null;
}
for (HandlerMethodArgumentResolver resolver : resolvers) {
if (resolverClass.isAssignableFrom(resolver.getClass())) {
return (T) resolver;
}
}
return null;
}
private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServerWebExchange exchange) {
ApplicationContext context = exchange.getApplicationContext();
if (context != null) {
String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class);
if (names.length > 0) {
return (RequestMappingHandlerAdapter) context.getBean(names[0]);
}
}
return null;
}
}
private OAuth2ClientServerTestUtils() {
}
}
}
}

View File

@ -35,6 +35,7 @@ import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@ -56,11 +57,14 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -89,10 +93,16 @@ import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.CsrfTokenRepository;
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.DigestUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter;
import static java.lang.Boolean.TRUE;
import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
@ -1657,9 +1667,16 @@ public final class SecurityMockMvcRequestPostProcessors {
}
OAuth2AuthorizedClient client = new OAuth2AuthorizedClient
(this.clientRegistration, this.principalName, this.accessToken);
OAuth2AuthorizedClientRepository authorizedClientRepository =
new HttpSessionOAuth2AuthorizedClientRepository();
authorizedClientRepository.saveAuthorizedClient(client, null, request, new MockHttpServletResponse());
OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils
.getOAuth2AuthorizedClientManager(request);
if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) {
authorizationClientManager =
new TestOAuth2AuthorizedClientManager(authorizationClientManager);
OAuth2ClientServletTestUtils.setOAuth2AuthorizedClientManager(request, authorizationClientManager);
}
TestOAuth2AuthorizedClientManager.enable(request);
request.setAttribute(TestOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client);
return request;
}
@ -1670,6 +1687,133 @@ public final class SecurityMockMvcRequestPostProcessors {
.clientSecret("test-secret")
.tokenUri("https://idp.example.org/oauth/token");
}
/**
* Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the
* request is wrapped
*/
private static class TestOAuth2AuthorizedClientManager
implements OAuth2AuthorizedClientManager {
final static String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName()
.concat(".TOKEN");
final static String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientManager.class
.getName().concat(".ENABLED");
private final OAuth2AuthorizedClientManager delegate;
private TestOAuth2AuthorizedClientManager(OAuth2AuthorizedClientManager delegate) {
this.delegate = delegate;
}
@Override
public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) {
HttpServletRequest request =
authorizeRequest.getAttribute(HttpServletRequest.class.getName());
if (isEnabled(request)) {
return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME);
} else {
return this.delegate.authorize(authorizeRequest);
}
}
public static void enable(HttpServletRequest request) {
request.setAttribute(ENABLED_ATTR_NAME, TRUE);
}
public boolean isEnabled(HttpServletRequest request) {
return TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
}
}
private static class OAuth2ClientServletTestUtils {
private static final OAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO =
new HttpSessionOAuth2AuthorizedClientRepository();
/**
* Gets the {@link OAuth2AuthorizedClientManager} for the specified {@link HttpServletRequest}.
* If one is not found, one based off of {@link HttpSessionOAuth2AuthorizedClientRepository} is used.
*
* @param request the {@link HttpServletRequest} to obtain the
* {@link OAuth2AuthorizedClientManager}
* @return the {@link OAuth2AuthorizedClientManager} for the specified
* {@link HttpServletRequest}
*/
public static OAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(HttpServletRequest request) {
OAuth2AuthorizedClientArgumentResolver resolver =
findResolver(request, OAuth2AuthorizedClientArgumentResolver.class);
if (resolver == null) {
return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient
(authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), request);
}
return (OAuth2AuthorizedClientManager)
ReflectionTestUtils.getField(resolver, "authorizedClientManager");
}
/**
* Sets the {@link OAuth2AuthorizedClientManager} for the specified {@link HttpServletRequest}.
*
* @param request the {@link HttpServletRequest} to obtain the
* {@link OAuth2AuthorizedClientManager}
* @param manager the {@link OAuth2AuthorizedClientManager} to set
*/
public static void setOAuth2AuthorizedClientManager(HttpServletRequest request,
OAuth2AuthorizedClientManager manager) {
OAuth2AuthorizedClientArgumentResolver resolver =
findResolver(request, OAuth2AuthorizedClientArgumentResolver.class);
if (resolver == null) {
return;
}
ReflectionTestUtils.setField(resolver, "authorizedClientManager", manager);
}
@SuppressWarnings("unchecked")
static <T extends HandlerMethodArgumentResolver> T findResolver(HttpServletRequest request,
Class<T> resolverClass) {
if (!ClassUtils.isPresent
("org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter", null)) {
return null;
}
return WebMvcClasspathGuard.findResolver(request, resolverClass);
}
private static class WebMvcClasspathGuard {
static <T extends HandlerMethodArgumentResolver> T findResolver(HttpServletRequest request,
Class<T> resolverClass) {
ServletContext servletContext = request.getServletContext();
RequestMappingHandlerAdapter mapping = getRequestMappingHandlerAdapter(servletContext);
if (mapping == null) {
return null;
}
List<HandlerMethodArgumentResolver> resolvers = mapping.getCustomArgumentResolvers();
if (resolvers == null) {
return null;
}
for (HandlerMethodArgumentResolver resolver : resolvers) {
if (resolverClass.isAssignableFrom(resolver.getClass())) {
return (T) resolver;
}
}
return null;
}
private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServletContext servletContext) {
WebApplicationContext context = WebApplicationContextUtils
.getWebApplicationContext(servletContext);
if (context != null) {
String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class);
if (names.length > 0) {
return (RequestMappingHandlerAdapter) context.getBean(names[0]);
}
}
return null;
}
}
private OAuth2ClientServletTestUtils() {
}
}
}
private SecurityMockMvcRequestPostProcessors() {

View File

@ -21,26 +21,32 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Client;
@ -53,18 +59,18 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private WebTestClient client;
@Before
public void setup() {
ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
new WebSessionServerOAuth2AuthorizedClientRepository();
this.client = WebTestClient
.bindToController(this.controller)
.argumentResolvers(c -> c.addCustomResolver(
new OAuth2AuthorizedClientArgumentResolver
(this.clientRegistrationRepository, authorizedClientRepository)))
(this.clientRegistrationRepository, this.authorizedClientRepository)))
.webFilter(new SecurityContextServerWebExchangeWebFilter())
.apply(springSecurity())
.configureClient()
@ -162,6 +168,32 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock
assertThat(client.getRefreshToken()).isNull();
}
@Test
public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exception {
this.client.mutateWith(mockOAuth2Client("registration-id"))
.get().uri("/client")
.exchange()
.expectStatus().isOk();
OAuth2AuthorizedClient client = this.controller.authorizedClient;
assertThat(client).isNotNull();
assertThat(client.getClientRegistration().getClientId()).isEqualTo("test-client");
client = new OAuth2AuthorizedClient(clientRegistration().build(), "sub", noScopes());
when(this.authorizedClientRepository
.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), any(ServerWebExchange.class)))
.thenReturn(Mono.just(client));
this.client
.get().uri("/client")
.exchange()
.expectStatus().isOk();
client = this.controller.authorizedClient;
assertThat(client).isNotNull();
assertThat(client.getClientRegistration().getClientId()).isEqualTo("client-id");
verify(this.authorizedClientRepository).loadAuthorizedClient(
eq("registration-id"), any(Authentication.class), any(ServerWebExchange.class));
}
@RestController
static class OAuth2LoginController {
volatile OAuth2AuthorizedClient authorizedClient;

View File

@ -36,7 +36,6 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
@ -55,18 +54,18 @@ public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockS
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private WebTestClient client;
@Before
public void setup() {
ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
new WebSessionServerOAuth2AuthorizedClientRepository();
this.client = WebTestClient
.bindToController(this.controller)
.argumentResolvers(c -> c.addCustomResolver(
new OAuth2AuthorizedClientArgumentResolver
(this.clientRegistrationRepository, authorizedClientRepository)))
(this.clientRegistrationRepository, this.authorizedClientRepository)))
.webFilter(new SecurityContextServerWebExchangeWebFilter())
.apply(springSecurity())
.configureClient()

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,7 +35,6 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@ -57,18 +56,18 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private WebTestClient client;
@Before
public void setup() {
ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
new WebSessionServerOAuth2AuthorizedClientRepository();
this.client = WebTestClient
.bindToController(this.controller)
.argumentResolvers(c -> c.addCustomResolver(
new OAuth2AuthorizedClientArgumentResolver
(this.clientRegistrationRepository, authorizedClientRepository)))
(this.clientRegistrationRepository, this.authorizedClientRepository)))
.webFilter(new SecurityContextServerWebExchangeWebFilter())
.apply(springSecurity())
.configureClient()

View File

@ -15,6 +15,8 @@
*/
package org.springframework.security.test.web.servlet.request;
import javax.servlet.http.HttpServletRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -26,11 +28,11 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.test.context.TestSecurityContextHolder;
@ -45,7 +47,11 @@ import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oauth2Client;
@ -138,6 +144,22 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests {
.andExpect(content().string("no-scopes"));
}
@Test
public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exception {
this.mvc.perform(get("/client-id")
.with(oauth2Client("registration-id")))
.andExpect(content().string("test-client"));
OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(clientRegistration().build(), "sub", noScopes());
OAuth2AuthorizedClientRepository repository = this.context.getBean(OAuth2AuthorizedClientRepository.class);
when(repository.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), any(HttpServletRequest.class)))
.thenReturn(client);
this.mvc.perform(get("/client-id"))
.andExpect(content().string("client-id"));
verify(repository).loadAuthorizedClient(
eq("registration-id"), any(Authentication.class), any(HttpServletRequest.class));
}
@EnableWebSecurity
@EnableWebMvc
static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter {
@ -158,7 +180,7 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests {
@Bean
OAuth2AuthorizedClientRepository authorizedClientRepository() {
return new HttpSessionOAuth2AuthorizedClientRepository();
return mock(OAuth2AuthorizedClientRepository.class);
}
@RestController

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,7 +37,6 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
@ -182,7 +181,7 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2LoginTests {
@Bean
OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository() {
return new HttpSessionOAuth2AuthorizedClientRepository();
return mock(OAuth2AuthorizedClientRepository.class);
}
@RestController

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,7 +36,6 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
@ -190,7 +189,7 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests {
@Bean
OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository() {
return new HttpSessionOAuth2AuthorizedClientRepository();
return mock(OAuth2AuthorizedClientRepository.class);
}
@RestController