OAuth2AuthorizationCodeGrantWebFilter matches on registered redirect-uri

Fixes gh-7036
This commit is contained in:
Joe Grandja 2019-09-23 16:51:12 -04:00
parent a5391b629e
commit 1749c8df9c
3 changed files with 108 additions and 29 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,11 +34,17 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.web.server.SecurityWebFilterChain;
@ -69,8 +75,11 @@ public class OAuth2ClientSpecTests {
private ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
private ApplicationContext context;
@Autowired
public void setApplicationContext(ApplicationContext context) {
this.context = context;
this.client = WebTestClient.bindToApplicationContext(context).build();
}
@ -140,19 +149,40 @@ public class OAuth2ClientSpecTests {
ServerAuthenticationConverter converter = config.authenticationConverter;
ReactiveAuthenticationManager manager = config.manager;
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new WebSessionOAuth2ServerAuthorizationRequestRepository();
OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success();
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes();
OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken);
OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(
this.registration, authorizationExchange, accessToken);
when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
when(manager.authenticate(any())).thenReturn(Mono.just(result));
this.client.get()
.uri("/authorize/oauth2/code/registration-id")
.exchange()
.expectStatus().is3xxRedirection();
WebTestClient client = WebTestClient.bindToApplicationContext(this.context)
.webFilter((exchange, chain) ->
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, exchange)
.then(chain.filter(exchange).then(Mono.empty()))
)
.build();
client.get()
.uri(uriBuilder ->
uriBuilder.path("/authorize/oauth2/code/registration-id")
.queryParam(OAuth2ParameterNames.CODE, "code")
.queryParam(OAuth2ParameterNames.STATE, "state")
.build())
.exchange()
.expectStatus().is3xxRedirection();
verify(converter).convert(any());
verify(manager).authenticate(any());

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,12 +35,13 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;
/**
@ -71,6 +72,7 @@ import reactor.core.publisher.Mono;
* </ul>
*
* @author Rob Winch
* @author Joe Grandja
* @since 5.1
* @see OAuth2AuthorizationCodeAuthenticationToken
* @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager
@ -89,6 +91,9 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new WebSessionOAuth2ServerAuthorizationRequestRepository();
private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
private ServerAuthenticationConverter authenticationConverter;
@ -109,7 +114,7 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authenticationManager = authenticationManager;
this.authorizedClientRepository = authorizedClientRepository;
this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}");
this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse;
this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
@ -124,7 +129,7 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authenticationManager = authenticationManager;
this.authorizedClientRepository = authorizedClientRepository;
this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}");
this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse;
this.authenticationConverter = authenticationConverter;
this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
@ -164,4 +169,22 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange()))
);
}
private Mono<ServerWebExchangeMatcher.MatchResult> matchesAuthorizationResponse(ServerWebExchange exchange) {
return this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
.flatMap(authorizationRequest -> {
String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI())
.query(null)
.build()
.toUriString();
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) {
return ServerWebExchangeMatcher.MatchResult.match();
}
return ServerWebExchangeMatcher.MatchResult.notMatch();
})
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,16 +28,22 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.web.server.handler.DefaultWebFilterChain;
import reactor.core.publisher.Mono;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
/**
* @author Rob Winch
@ -53,6 +59,9 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new WebSessionOAuth2ServerAuthorizationRequestRepository();
@Before
public void setup() {
this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
@ -101,25 +110,42 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
@Test
public void filterWhenMatchThenAuthorizedClientSaved() {
Mono<Authentication> authentication = Mono
.just(TestOAuth2AuthorizationCodeAuthenticationTokens.unauthenticated());
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri("/authorize/registration-id")
.build();
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.redirectUri("/authorize/registration-id")
.build();
OAuth2AuthorizationExchange authorizationExchange =
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
Mono<Authentication> authentication = Mono.just(
new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange));
OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
.authenticated();
ServerAuthenticationConverter converter = e -> authentication;
this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
this.authenticationManager, converter, this.authorizedClientRepository);
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest
.get("/authorize/oauth2/code/registration-id"));
DefaultWebFilterChain chain = new DefaultWebFilterChain(
e -> e.getResponse().setComplete());
when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just(
authenticated));
when(this.authenticationManager.authenticate(any())).thenReturn(
Mono.just(authenticated));
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
.thenReturn(Mono.empty());
ServerAuthenticationConverter converter = e -> authentication;
this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
this.authenticationManager, converter, this.authorizedClientRepository);
MockServerHttpRequest request = MockServerHttpRequest
.get("/authorize/registration-id")
.queryParam(OAuth2ParameterNames.CODE, "code")
.queryParam(OAuth2ParameterNames.STATE, "state")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
DefaultWebFilterChain chain = new DefaultWebFilterChain(
e -> e.getResponse().setComplete());
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, exchange).block();
this.filter.filter(exchange, chain).block();
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any());
}
}