From 1749c8df9c688e8bed3f72ee099eaf46df8c0187 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 23 Sep 2019 16:51:12 -0400 Subject: [PATCH] OAuth2AuthorizationCodeGrantWebFilter matches on registered redirect-uri Fixes gh-7036 --- .../web/server/OAuth2ClientSpecTests.java | 46 +++++++++++--- ...OAuth2AuthorizationCodeGrantWebFilter.java | 31 ++++++++-- ...2AuthorizationCodeGrantWebFilterTests.java | 60 +++++++++++++------ 3 files changed, 108 insertions(+), 29 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java index 9982ca608a..1755d8e427 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,11 +34,17 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.server.SecurityWebFilterChain; @@ -69,8 +75,11 @@ public class OAuth2ClientSpecTests { private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); + private ApplicationContext context; + @Autowired public void setApplicationContext(ApplicationContext context) { + this.context = context; this.client = WebTestClient.bindToApplicationContext(context).build(); } @@ -140,19 +149,40 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter converter = config.authenticationConverter; ReactiveAuthenticationManager manager = config.manager; + ServerAuthorizationRequestRepository authorizationRequestRepository = + new WebSessionOAuth2ServerAuthorizationRequestRepository(); - OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken); + OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( + this.registration, authorizationExchange, accessToken); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(manager.authenticate(any())).thenReturn(Mono.just(result)); - this.client.get() - .uri("/authorize/oauth2/code/registration-id") - .exchange() - .expectStatus().is3xxRedirection(); + WebTestClient client = WebTestClient.bindToApplicationContext(this.context) + .webFilter((exchange, chain) -> + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, exchange) + .then(chain.filter(exchange).then(Mono.empty())) + ) + .build(); + + client.get() + .uri(uriBuilder -> + uriBuilder.path("/authorize/oauth2/code/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build()) + .exchange() + .expectStatus().is3xxRedirection(); verify(converter).convert(any()); verify(manager).authenticate(any()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index 02e76b1a1a..ef489ab842 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,12 +35,13 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; -import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; /** @@ -71,6 +72,7 @@ import reactor.core.publisher.Mono; * * * @author Rob Winch + * @author Joe Grandja * @since 5.1 * @see OAuth2AuthorizationCodeAuthenticationToken * @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager @@ -89,6 +91,9 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ServerAuthorizationRequestRepository authorizationRequestRepository = + new WebSessionOAuth2ServerAuthorizationRequestRepository(); + private ServerAuthenticationSuccessHandler authenticationSuccessHandler; private ServerAuthenticationConverter authenticationConverter; @@ -109,7 +114,7 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.authenticationManager = authenticationManager; this.authorizedClientRepository = authorizedClientRepository; - this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}"); + this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); @@ -124,7 +129,7 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.authenticationManager = authenticationManager; this.authorizedClientRepository = authorizedClientRepository; - this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}"); + this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; this.authenticationConverter = authenticationConverter; this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); @@ -164,4 +169,22 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange())) ); } + + private Mono matchesAuthorizationResponse(ServerWebExchange exchange) { + return this.authorizationRequestRepository.loadAuthorizationRequest(exchange) + .flatMap(authorizationRequest -> { + String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()) + .query(null) + .build() + .toUriString(); + MultiValueMap queryParams = exchange.getRequest().getQueryParams(); + if (requestUrl.equals(authorizationRequest.getRedirectUri()) && + OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) { + return ServerWebExchangeMatcher.MatchResult.match(); + } + return ServerWebExchangeMatcher.MatchResult.notMatch(); + }) + .filter(ServerWebExchangeMatcher.MatchResult::isMatch) + .switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch()); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java index 5727a26fa3..1124f6d40a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,16 +28,22 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.web.server.handler.DefaultWebFilterChain; import reactor.core.publisher.Mono; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * @author Rob Winch @@ -53,6 +59,9 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ServerAuthorizationRequestRepository authorizationRequestRepository = + new WebSessionOAuth2ServerAuthorizationRequestRepository(); + @Before public void setup() { this.filter = new OAuth2AuthorizationCodeGrantWebFilter( @@ -101,25 +110,42 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Test public void filterWhenMatchThenAuthorizedClientSaved() { - Mono authentication = Mono - .just(TestOAuth2AuthorizationCodeAuthenticationTokens.unauthenticated()); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/registration-id") + .build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/registration-id") + .build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); + Mono authentication = Mono.just( + new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange)); OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens .authenticated(); - ServerAuthenticationConverter converter = e -> authentication; - this.filter = new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, converter, this.authorizedClientRepository); - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("/authorize/oauth2/code/registration-id")); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete()); - when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just( - authenticated)); + + when(this.authenticationManager.authenticate(any())).thenReturn( + Mono.just(authenticated)); when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) .thenReturn(Mono.empty()); + ServerAuthenticationConverter converter = e -> authentication; + + this.filter = new OAuth2AuthorizationCodeGrantWebFilter( + this.authenticationManager, converter, this.authorizedClientRepository); + + MockServerHttpRequest request = MockServerHttpRequest + .get("/authorize/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete()); + + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, exchange).block(); this.filter.filter(exchange, chain).block(); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any()); - } }