OAuth2AuthorizationCodeGrantWebFilter matches on registered redirect-uri

Fixes gh-7036
This commit is contained in:
Joe Grandja 2019-09-23 16:51:12 -04:00
parent 6f6f5a12da
commit 9f18c2e21a
4 changed files with 182 additions and 36 deletions

View File

@ -31,6 +31,8 @@ import java.util.UUID;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.util.context.Context; import reactor.util.context.Context;
@ -231,6 +233,7 @@ import static org.springframework.security.web.server.util.matcher.ServerWebExch
* @author Vedran Pavic * @author Vedran Pavic
* @author Rafiullah Hamedy * @author Rafiullah Hamedy
* @author Eddú Meléndez * @author Eddú Meléndez
* @author Joe Grandja
* @since 5.0 * @since 5.0
*/ */
public class ServerHttpSecurity { public class ServerHttpSecurity {
@ -1317,6 +1320,8 @@ public class ServerHttpSecurity {
private ReactiveAuthenticationManager authenticationManager; private ReactiveAuthenticationManager authenticationManager;
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
/** /**
* Sets the converter to use * Sets the converter to use
* @param authenticationConverter the converter to use * @param authenticationConverter the converter to use
@ -1329,7 +1334,10 @@ public class ServerHttpSecurity {
private ServerAuthenticationConverter getAuthenticationConverter() { private ServerAuthenticationConverter getAuthenticationConverter() {
if (this.authenticationConverter == null) { if (this.authenticationConverter == null) {
this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(getClientRegistrationRepository()); ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter =
new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(getClientRegistrationRepository());
authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
this.authenticationConverter = authenticationConverter;
} }
return this.authenticationConverter; return this.authenticationConverter;
} }
@ -1378,6 +1386,26 @@ public class ServerHttpSecurity {
return this; return this;
} }
/**
* Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s.
*
* @since 5.2
* @param authorizationRequestRepository the repository to use for storing {@link OAuth2AuthorizationRequest}'s
* @return the {@link OAuth2ClientSpec} to customize
*/
public OAuth2ClientSpec authorizationRequestRepository(
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
this.authorizationRequestRepository = authorizationRequestRepository;
return this;
}
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> getAuthorizationRequestRepository() {
if (this.authorizationRequestRepository == null) {
this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
}
return this.authorizationRequestRepository;
}
/** /**
* Allows method chaining to continue configuring the {@link ServerHttpSecurity} * Allows method chaining to continue configuring the {@link ServerHttpSecurity}
* @return the {@link ServerHttpSecurity} to continue configuring * @return the {@link ServerHttpSecurity} to continue configuring
@ -1391,12 +1419,13 @@ public class ServerHttpSecurity {
ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
ServerAuthenticationConverter authenticationConverter = getAuthenticationConverter(); ServerAuthenticationConverter authenticationConverter = getAuthenticationConverter();
ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter(authenticationManager, OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter(
authenticationConverter, authenticationManager, authenticationConverter, authorizedClientRepository);
authorizedClientRepository); codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(
clientRegistrationRepository); clientRegistrationRepository);
oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE);
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
} }

View File

@ -34,11 +34,16 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain;
@ -140,19 +145,33 @@ public class OAuth2ClientSpecTests {
ServerAuthenticationConverter converter = config.authenticationConverter; ServerAuthenticationConverter converter = config.authenticationConverter;
ReactiveAuthenticationManager manager = config.manager; ReactiveAuthenticationManager manager = config.manager;
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository;
OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes();
OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken); OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(
this.registration, authorizationExchange, accessToken);
when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest));
when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
when(manager.authenticate(any())).thenReturn(Mono.just(result)); when(manager.authenticate(any())).thenReturn(Mono.just(result));
this.client.get() this.client.get()
.uri("/authorize/oauth2/code/registration-id") .uri(uriBuilder ->
.exchange() uriBuilder.path("/authorize/oauth2/code/registration-id")
.expectStatus().is3xxRedirection(); .queryParam(OAuth2ParameterNames.CODE, "code")
.queryParam(OAuth2ParameterNames.STATE, "state")
.build())
.exchange()
.expectStatus().is3xxRedirection();
verify(converter).convert(any()); verify(converter).convert(any());
verify(manager).authenticate(any()); verify(manager).authenticate(any());
@ -176,12 +195,15 @@ public class OAuth2ClientSpecTests {
ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
@Bean @Bean
public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
http http
.oauth2Client() .oauth2Client()
.authenticationConverter(this.authenticationConverter) .authenticationConverter(this.authenticationConverter)
.authenticationManager(this.manager); .authenticationManager(this.manager)
.authorizationRequestRepository(this.authorizationRequestRepository);
return http.build(); return http.build();
} }
} }
@ -194,17 +216,31 @@ public class OAuth2ClientSpecTests {
ServerAuthenticationConverter converter = config.authenticationConverter; ServerAuthenticationConverter converter = config.authenticationConverter;
ReactiveAuthenticationManager manager = config.manager; ReactiveAuthenticationManager manager = config.manager;
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository;
OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes();
OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken); OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(
this.registration, authorizationExchange, accessToken);
when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest));
when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
when(manager.authenticate(any())).thenReturn(Mono.just(result)); when(manager.authenticate(any())).thenReturn(Mono.just(result));
this.client.get() this.client.get()
.uri("/authorize/oauth2/code/registration-id") .uri(uriBuilder ->
uriBuilder.path("/authorize/oauth2/code/registration-id")
.queryParam(OAuth2ParameterNames.CODE, "code")
.queryParam(OAuth2ParameterNames.STATE, "state")
.build())
.exchange() .exchange()
.expectStatus().is3xxRedirection(); .expectStatus().is3xxRedirection();
@ -218,6 +254,8 @@ public class OAuth2ClientSpecTests {
ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
@Bean @Bean
public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
http http
@ -225,6 +263,7 @@ public class OAuth2ClientSpecTests {
oauth2Client oauth2Client
.authenticationConverter(this.authenticationConverter) .authenticationConverter(this.authenticationConverter)
.authenticationManager(this.manager) .authenticationManager(this.manager)
.authorizationRequestRepository(this.authorizationRequestRepository)
); );
return http.build(); return http.build();
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -35,12 +35,13 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
/** /**
@ -71,6 +72,7 @@ import reactor.core.publisher.Mono;
* </ul> * </ul>
* *
* @author Rob Winch * @author Rob Winch
* @author Joe Grandja
* @since 5.1 * @since 5.1
* @see OAuth2AuthorizationCodeAuthenticationToken * @see OAuth2AuthorizationCodeAuthenticationToken
* @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager * @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager
@ -89,10 +91,15 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new WebSessionOAuth2ServerAuthorizationRequestRepository();
private ServerAuthenticationSuccessHandler authenticationSuccessHandler; private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
private ServerAuthenticationConverter authenticationConverter; private ServerAuthenticationConverter authenticationConverter;
private boolean defaultAuthenticationConverter;
private ServerAuthenticationFailureHandler authenticationFailureHandler; private ServerAuthenticationFailureHandler authenticationFailureHandler;
private ServerWebExchangeMatcher requiresAuthenticationMatcher; private ServerWebExchangeMatcher requiresAuthenticationMatcher;
@ -109,8 +116,12 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authenticationManager = authenticationManager; this.authenticationManager = authenticationManager;
this.authorizedClientRepository = authorizedClientRepository; this.authorizedClientRepository = authorizedClientRepository;
this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}"); this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse;
this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter =
new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
authenticationConverter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
this.authenticationConverter = authenticationConverter;
this.defaultAuthenticationConverter = true;
this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
} }
@ -124,12 +135,33 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authenticationManager = authenticationManager; this.authenticationManager = authenticationManager;
this.authorizedClientRepository = authorizedClientRepository; this.authorizedClientRepository = authorizedClientRepository;
this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}"); this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse;
this.authenticationConverter = authenticationConverter; this.authenticationConverter = authenticationConverter;
this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
} }
/**
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
* The default is {@link WebSessionOAuth2ServerAuthorizationRequestRepository}.
*
* @since 5.2
* @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s
*/
public final void setAuthorizationRequestRepository(
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
this.authorizationRequestRepository = authorizationRequestRepository;
updateDefaultAuthenticationConverter();
}
private void updateDefaultAuthenticationConverter() {
if (this.defaultAuthenticationConverter) {
((ServerOAuth2AuthorizationCodeAuthenticationTokenConverter) this.authenticationConverter)
.setAuthorizationRequestRepository(this.authorizationRequestRepository);
}
}
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return this.requiresAuthenticationMatcher.matches(exchange) return this.requiresAuthenticationMatcher.matches(exchange)
@ -164,4 +196,22 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange())) .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange()))
); );
} }
private Mono<ServerWebExchangeMatcher.MatchResult> matchesAuthorizationResponse(ServerWebExchange exchange) {
return this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
.flatMap(authorizationRequest -> {
String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI())
.query(null)
.build()
.toUriString();
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) {
return ServerWebExchangeMatcher.MatchResult.match();
}
return ServerWebExchangeMatcher.MatchResult.notMatch();
})
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,16 +28,22 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens; import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.web.server.handler.DefaultWebFilterChain; import org.springframework.web.server.handler.DefaultWebFilterChain;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.*;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
/** /**
* @author Rob Winch * @author Rob Winch
@ -52,12 +58,16 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
private ReactiveClientRegistrationRepository clientRegistrationRepository; private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock @Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
@Mock
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
@Before @Before
public void setup() { public void setup() {
this.filter = new OAuth2AuthorizationCodeGrantWebFilter( this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
this.authenticationManager, this.clientRegistrationRepository, this.authenticationManager, this.clientRegistrationRepository,
this.authorizedClientRepository); this.authorizedClientRepository);
when(this.authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.empty());
this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
} }
@Test @Test
@ -101,25 +111,43 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@Test @Test
public void filterWhenMatchThenAuthorizedClientSaved() { public void filterWhenMatchThenAuthorizedClientSaved() {
Mono<Authentication> authentication = Mono OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.just(TestOAuth2AuthorizationCodeAuthenticationTokens.unauthenticated()); .redirectUri("/authorize/registration-id")
.build();
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.redirectUri("/authorize/registration-id")
.build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
Mono<Authentication> authentication = Mono.just(
new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange));
OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
.authenticated(); .authenticated();
ServerAuthenticationConverter converter = e -> authentication;
this.filter = new OAuth2AuthorizationCodeGrantWebFilter( when(this.authenticationManager.authenticate(any())).thenReturn(
this.authenticationManager, converter, this.authorizedClientRepository); Mono.just(authenticated));
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
.get("/authorize/oauth2/code/registration-id")); .thenReturn(Mono.just(authorizationRequest));
DefaultWebFilterChain chain = new DefaultWebFilterChain(
e -> e.getResponse().setComplete());
when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just(
authenticated));
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
.thenReturn(Mono.empty()); .thenReturn(Mono.empty());
ServerAuthenticationConverter converter = e -> authentication;
this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
this.authenticationManager, converter, this.authorizedClientRepository);
this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
MockServerHttpRequest request = MockServerHttpRequest
.get("/authorize/registration-id")
.queryParam(OAuth2ParameterNames.CODE, "code")
.queryParam(OAuth2ParameterNames.STATE, "state")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
DefaultWebFilterChain chain = new DefaultWebFilterChain(
e -> e.getResponse().setComplete());
this.filter.filter(exchange, chain).block(); this.filter.filter(exchange, chain).block();
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any());
} }
} }