diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 405121fbf7..2b36f0390d 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -31,6 +31,8 @@ import java.util.UUID; import java.util.function.Function; import java.util.function.Supplier; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository; import reactor.core.publisher.Mono; import reactor.util.context.Context; @@ -231,6 +233,7 @@ import static org.springframework.security.web.server.util.matcher.ServerWebExch * @author Vedran Pavic * @author Rafiullah Hamedy * @author EddĂș MelĂ©ndez + * @author Joe Grandja * @since 5.0 */ public class ServerHttpSecurity { @@ -1317,6 +1320,8 @@ public class ServerHttpSecurity { private ReactiveAuthenticationManager authenticationManager; + private ServerAuthorizationRequestRepository authorizationRequestRepository; + /** * Sets the converter to use * @param authenticationConverter the converter to use @@ -1329,7 +1334,10 @@ public class ServerHttpSecurity { private ServerAuthenticationConverter getAuthenticationConverter() { if (this.authenticationConverter == null) { - this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(getClientRegistrationRepository()); + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = + new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(getClientRegistrationRepository()); + authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + this.authenticationConverter = authenticationConverter; } return this.authenticationConverter; } @@ -1378,6 +1386,26 @@ public class ServerHttpSecurity { return this; } + /** + * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s. + * + * @since 5.2 + * @param authorizationRequestRepository the repository to use for storing {@link OAuth2AuthorizationRequest}'s + * @return the {@link OAuth2ClientSpec} to customize + */ + public OAuth2ClientSpec authorizationRequestRepository( + ServerAuthorizationRequestRepository authorizationRequestRepository) { + this.authorizationRequestRepository = authorizationRequestRepository; + return this; + } + + private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { + if (this.authorizationRequestRepository == null) { + this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + } + return this.authorizationRequestRepository; + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -1391,12 +1419,13 @@ public class ServerHttpSecurity { ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository(); ServerAuthenticationConverter authenticationConverter = getAuthenticationConverter(); ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); - OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter(authenticationManager, - authenticationConverter, - authorizedClientRepository); + OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter( + authenticationManager, authenticationConverter, authorizedClientRepository); + codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( clientRegistrationRepository); + oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java index 6fec4cf502..4e07a65d78 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java @@ -34,11 +34,16 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.server.SecurityWebFilterChain; @@ -140,19 +145,33 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter converter = config.authenticationConverter; ReactiveAuthenticationManager manager = config.manager; + ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; - OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken); + OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( + this.registration, authorizationExchange, accessToken); + when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(manager.authenticate(any())).thenReturn(Mono.just(result)); this.client.get() - .uri("/authorize/oauth2/code/registration-id") - .exchange() - .expectStatus().is3xxRedirection(); + .uri(uriBuilder -> + uriBuilder.path("/authorize/oauth2/code/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build()) + .exchange() + .expectStatus().is3xxRedirection(); verify(converter).convert(any()); verify(manager).authenticate(any()); @@ -176,12 +195,15 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); + ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + @Bean public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { http .oauth2Client() .authenticationConverter(this.authenticationConverter) - .authenticationManager(this.manager); + .authenticationManager(this.manager) + .authorizationRequestRepository(this.authorizationRequestRepository); return http.build(); } } @@ -194,17 +216,31 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter converter = config.authenticationConverter; ReactiveAuthenticationManager manager = config.manager; + ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; - OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success(); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/oauth2/code/registration-id") + .build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes(); - OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken); + OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken( + this.registration, authorizationExchange, accessToken); + when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(manager.authenticate(any())).thenReturn(Mono.just(result)); this.client.get() - .uri("/authorize/oauth2/code/registration-id") + .uri(uriBuilder -> + uriBuilder.path("/authorize/oauth2/code/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build()) .exchange() .expectStatus().is3xxRedirection(); @@ -218,6 +254,8 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); + ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + @Bean public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { http @@ -225,6 +263,7 @@ public class OAuth2ClientSpecTests { oauth2Client .authenticationConverter(this.authenticationConverter) .authenticationManager(this.manager) + .authorizationRequestRepository(this.authorizationRequestRepository) ); return http.build(); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index 02e76b1a1a..de6e31d458 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,12 +35,13 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; -import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; /** @@ -71,6 +72,7 @@ import reactor.core.publisher.Mono; * * * @author Rob Winch + * @author Joe Grandja * @since 5.1 * @see OAuth2AuthorizationCodeAuthenticationToken * @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager @@ -89,10 +91,15 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ServerAuthorizationRequestRepository authorizationRequestRepository = + new WebSessionOAuth2ServerAuthorizationRequestRepository(); + private ServerAuthenticationSuccessHandler authenticationSuccessHandler; private ServerAuthenticationConverter authenticationConverter; + private boolean defaultAuthenticationConverter; + private ServerAuthenticationFailureHandler authenticationFailureHandler; private ServerWebExchangeMatcher requiresAuthenticationMatcher; @@ -109,8 +116,12 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.authenticationManager = authenticationManager; this.authorizedClientRepository = authorizedClientRepository; - this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}"); - this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; + ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = + new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository); + authenticationConverter.setAuthorizationRequestRepository(this.authorizationRequestRepository); + this.authenticationConverter = authenticationConverter; + this.defaultAuthenticationConverter = true; this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); } @@ -124,12 +135,33 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.authenticationManager = authenticationManager; this.authorizedClientRepository = authorizedClientRepository; - this.requiresAuthenticationMatcher = new PathPatternParserServerWebExchangeMatcher("/{action}/oauth2/code/{registrationId}"); + this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; this.authenticationConverter = authenticationConverter; this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); } + /** + * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. + * The default is {@link WebSessionOAuth2ServerAuthorizationRequestRepository}. + * + * @since 5.2 + * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s + */ + public final void setAuthorizationRequestRepository( + ServerAuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + updateDefaultAuthenticationConverter(); + } + + private void updateDefaultAuthenticationConverter() { + if (this.defaultAuthenticationConverter) { + ((ServerOAuth2AuthorizationCodeAuthenticationTokenConverter) this.authenticationConverter) + .setAuthorizationRequestRepository(this.authorizationRequestRepository); + } + } + @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return this.requiresAuthenticationMatcher.matches(exchange) @@ -164,4 +196,22 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange())) ); } + + private Mono matchesAuthorizationResponse(ServerWebExchange exchange) { + return this.authorizationRequestRepository.loadAuthorizationRequest(exchange) + .flatMap(authorizationRequest -> { + String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()) + .query(null) + .build() + .toUriString(); + MultiValueMap queryParams = exchange.getRequest().getQueryParams(); + if (requestUrl.equals(authorizationRequest.getRedirectUri()) && + OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) { + return ServerWebExchangeMatcher.MatchResult.match(); + } + return ServerWebExchangeMatcher.MatchResult.notMatch(); + }) + .filter(ServerWebExchangeMatcher.MatchResult::isMatch) + .switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch()); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java index 5727a26fa3..f86256cbe0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,16 +28,22 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.web.server.handler.DefaultWebFilterChain; import reactor.core.publisher.Mono; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * @author Rob Winch @@ -52,12 +58,16 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { private ReactiveClientRegistrationRepository clientRegistrationRepository; @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + @Mock + private ServerAuthorizationRequestRepository authorizationRequestRepository; @Before public void setup() { this.filter = new OAuth2AuthorizationCodeGrantWebFilter( this.authenticationManager, this.clientRegistrationRepository, this.authorizedClientRepository); + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.empty()); + this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } @Test @@ -101,25 +111,43 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Test public void filterWhenMatchThenAuthorizedClientSaved() { - Mono authentication = Mono - .just(TestOAuth2AuthorizationCodeAuthenticationTokens.unauthenticated()); + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() + .redirectUri("/authorize/registration-id") + .build(); + OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() + .redirectUri("/authorize/registration-id") + .build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); + ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); + Mono authentication = Mono.just( + new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange)); OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens .authenticated(); - ServerAuthenticationConverter converter = e -> authentication; - this.filter = new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, converter, this.authorizedClientRepository); - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("/authorize/oauth2/code/registration-id")); - DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete()); - when(this.authenticationManager.authenticate(any())).thenReturn(Mono.just( - authenticated)); + + when(this.authenticationManager.authenticate(any())).thenReturn( + Mono.just(authenticated)); + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .thenReturn(Mono.just(authorizationRequest)); when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) .thenReturn(Mono.empty()); + ServerAuthenticationConverter converter = e -> authentication; + + this.filter = new OAuth2AuthorizationCodeGrantWebFilter( + this.authenticationManager, converter, this.authorizedClientRepository); + this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); + + MockServerHttpRequest request = MockServerHttpRequest + .get("/authorize/registration-id") + .queryParam(OAuth2ParameterNames.CODE, "code") + .queryParam(OAuth2ParameterNames.STATE, "state") + .build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete()); this.filter.filter(exchange, chain).block(); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any()); - } }