Allow to customize OAuth2AuthorizationRequestRedirectWebFilter in OAuth2LoginSpec

Fixes gh-7466
This commit is contained in:
Roman Chigvintsev 2019-09-23 14:12:16 +03:00 committed by Joe Grandja
parent 2a5bd6e719
commit 9bae0a4dbd
2 changed files with 108 additions and 3 deletions

View File

@ -76,6 +76,7 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
@ -972,6 +973,8 @@ public class ServerHttpSecurity {
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private ReactiveAuthenticationManager authenticationManager;
private ServerSecurityContextRepository securityContextRepository;
@ -1102,6 +1105,18 @@ public class ServerHttpSecurity {
return this;
}
/**
* Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}.
*
* @param authorizationRequestRepository authorization request repository, must not be null
* @return the {@link OAuth2LoginSpec} for further configuration
*/
public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
this.authorizationRequestRepository = authorizationRequestRepository;
return this;
}
/**
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
*
@ -1146,6 +1161,12 @@ public class ServerHttpSecurity {
ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
getAuthorizationRequestRepository();
if (authorizationRequestRepository != null) {
oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository);
}
oauthRedirectFilter.setRequestCache(http.requestCache.requestCache);
ReactiveAuthenticationManager manager = getAuthenticationManager();
@ -1246,6 +1267,14 @@ public class ServerHttpSecurity {
return result;
}
@SuppressWarnings("unchecked")
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> getAuthorizationRequestRepository() {
if (this.authorizationRequestRepository == null) {
this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class);
}
return this.authorizationRequestRepository;
}
private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() {
ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
if (service == null) {

View File

@ -20,12 +20,14 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -41,6 +43,7 @@ import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.web.server.handler.FilteringWebHandler;
import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher;
@ -48,18 +51,29 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
@ -68,10 +82,7 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
/**
* @author Rob Winch
@ -475,6 +486,71 @@ public class ServerHttpSecurityTests {
verify(customServerCsrfTokenRepository).loadToken(any());
}
@SuppressWarnings("UnassignedFluxMonoInstance")
@Test
public void configureOAuth2LoginUsingCustomCommonServerRequestCache() {
ServerRequestCache requestCacheMock = mock(ServerRequestCache.class);
when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty());
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
String registrationId = clientRegistration.getRegistrationId();
ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
mock(ReactiveClientRegistrationRepository.class);
when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
.thenReturn(Mono.just(clientRegistration));
SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock)
.and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock)
.and().build();
Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
assertThat(redirectWebFilter.isPresent()).isTrue();
FilteringWebHandler webHandler = new FilteringWebHandler(
e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)),
Collections.singletonList(redirectWebFilter.get())
);
WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build();
client.get().uri("/foo/bar").exchange();
verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class));
}
@Test(expected = IllegalArgumentException.class)
public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() {
http.oauth2Login().authorizationRequestRepository(null).and().build();
}
@SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"})
@Test
public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
String registrationId = clientRegistration.getRegistrationId();
ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
mock(ReactiveClientRegistrationRepository.class);
when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
.thenReturn(Mono.just(clientRegistration));
ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class);
SecurityWebFilterChain filterChain = http.oauth2Login()
.clientRegistrationRepository(clientRegistrationRepositoryMock)
.authorizationRequestRepository(requestRepositoryMock)
.and().build();
Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
assertThat(redirectWebFilter.isPresent()).isTrue();
WebTestClient client = WebTestClient.bindToController(new SubscriberContextController())
.webFilter(redirectWebFilter.get())
.build();
client.get().uri("/oauth2/authorization/" + registrationId).exchange();
verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class),
any(ServerWebExchange.class));
}
private boolean isX509Filter(WebFilter filter) {
try {
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");