diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 7f2275c9a6..65ff5edb38 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -236,6 +236,7 @@ import static org.springframework.security.web.server.util.matcher.ServerWebExch * @author Rafiullah Hamedy * @author EddĂș MelĂ©ndez * @author Joe Grandja + * @author Parikshit Dutta * @since 5.0 */ public class ServerHttpSecurity { @@ -1511,10 +1512,17 @@ public class ServerHttpSecurity { OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter( authenticationManager, authenticationConverter, authorizedClientRepository); codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + if (http.requestCache != null) { + codeGrantWebFilter.setRequestCache(http.requestCache.requestCache); + } OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( clientRegistrationRepository); oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + if (http.requestCache != null) { + oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); + } + http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java index 4e07a65d78..0ea9c446da 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.security.config.web.server; +import java.net.URI; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -48,6 +50,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; @@ -62,6 +65,7 @@ import static org.mockito.Mockito.when; /** * @author Rob Winch + * @author Parikshit Dutta * @since 5.1 */ @RunWith(SpringRunner.class) @@ -146,6 +150,7 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter converter = config.authenticationConverter; ReactiveAuthenticationManager manager = config.manager; ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; + ServerRequestCache requestCache = config.requestCache; OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .redirectUri("/authorize/oauth2/code/registration-id") @@ -163,6 +168,7 @@ public class OAuth2ClientSpecTests { when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(manager.authenticate(any())).thenReturn(Mono.just(result)); + when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request"))); this.client.get() .uri(uriBuilder -> @@ -175,6 +181,7 @@ public class OAuth2ClientSpecTests { verify(converter).convert(any()); verify(manager).authenticate(any()); + verify(requestCache).getRedirectUri(any()); } @EnableWebFlux @@ -197,13 +204,17 @@ public class OAuth2ClientSpecTests { ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + ServerRequestCache requestCache = mock(ServerRequestCache.class); + @Bean public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { http .oauth2Client() .authenticationConverter(this.authenticationConverter) .authenticationManager(this.manager) - .authorizationRequestRepository(this.authorizationRequestRepository); + .authorizationRequestRepository(this.authorizationRequestRepository) + .and() + .requestCache(c -> c.requestCache(this.requestCache)); return http.build(); } } @@ -217,6 +228,7 @@ public class OAuth2ClientSpecTests { ServerAuthenticationConverter converter = config.authenticationConverter; ReactiveAuthenticationManager manager = config.manager; ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; + ServerRequestCache requestCache = config.requestCache; OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .redirectUri("/authorize/oauth2/code/registration-id") @@ -234,6 +246,7 @@ public class OAuth2ClientSpecTests { when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); when(manager.authenticate(any())).thenReturn(Mono.just(result)); + when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request"))); this.client.get() .uri(uriBuilder -> @@ -246,6 +259,7 @@ public class OAuth2ClientSpecTests { verify(converter).convert(any()); verify(manager).authenticate(any()); + verify(requestCache).getRedirectUri(any()); } @Configuration @@ -256,6 +270,8 @@ public class OAuth2ClientSpecTests { ServerAuthorizationRequestRepository authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); + ServerRequestCache requestCache = mock(ServerRequestCache.class); + @Bean public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { http @@ -263,8 +279,8 @@ public class OAuth2ClientSpecTests { oauth2Client .authenticationConverter(this.authenticationConverter) .authenticationManager(this.manager) - .authorizationRequestRepository(this.authorizationRequestRepository) - ); + .authorizationRequestRepository(this.authorizationRequestRepository)) + .requestCache(c -> c.requestCache(this.requestCache)); return http.build(); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index 7ef55667cb..59e019d97e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -35,6 +35,8 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; +import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; @@ -80,6 +82,7 @@ import java.util.Set; * * @author Rob Winch * @author Joe Grandja + * @author Parikshit Dutta * @since 5.1 * @see OAuth2AuthorizationCodeAuthenticationToken * @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager @@ -111,6 +114,8 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { private ServerWebExchangeMatcher requiresAuthenticationMatcher; + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); + private AnonymousAuthenticationToken anonymousToken = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); @@ -129,7 +134,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { authenticationConverter.setAuthorizationRequestRepository(this.authorizationRequestRepository); this.authenticationConverter = authenticationConverter; this.defaultAuthenticationConverter = true; - this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); + RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = + new RedirectServerAuthenticationSuccessHandler(); + authenticationSuccessHandler.setRequestCache(this.requestCache); + this.authenticationSuccessHandler = authenticationSuccessHandler; this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); } @@ -144,7 +152,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { this.authorizedClientRepository = authorizedClientRepository; this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; this.authenticationConverter = authenticationConverter; - this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); + RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = + new RedirectServerAuthenticationSuccessHandler(); + authenticationSuccessHandler.setRequestCache(this.requestCache); + this.authenticationSuccessHandler = authenticationSuccessHandler; this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); } @@ -169,6 +180,23 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { } } + /** + * Sets the {@link ServerRequestCache} used for loading a previously saved request (if available) + * and replaying it after completing the processing of the OAuth 2.0 Authorization Response. + * + * @since 5.4 + * @param requestCache the cache used for loading a previously saved request (if available) + */ + public final void setRequestCache(ServerRequestCache requestCache) { + Assert.notNull(requestCache, "requestCache cannot be null"); + this.requestCache = requestCache; + updateDefaultAuthenticationSuccessHandler(); + } + + private void updateDefaultAuthenticationSuccessHandler() { + ((RedirectServerAuthenticationSuccessHandler) this.authenticationSuccessHandler).setRequestCache(this.requestCache); + } + @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return this.requiresAuthenticationMatcher.matches(exchange) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java index 3c2a4bcc2d..54d8893e9e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java @@ -31,17 +31,22 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.util.CollectionUtils; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.handler.DefaultWebFilterChain; import reactor.core.publisher.Mono; +import java.net.URI; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -50,6 +55,7 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author /** * @author Rob Winch + * @author Parikshit Dutta * @since 5.1 */ @RunWith(MockitoJUnitRunner.class) @@ -99,6 +105,12 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { + assertThatCode(() -> this.filter.setRequestCache(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void filterWhenNotMatchThenAuthenticationManagerNotCalled() { MockServerWebExchange exchange = MockServerWebExchange @@ -233,6 +245,40 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { verifyNoInteractions(this.authenticationManager); } + @Test + public void filterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestCacheUsed() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + when(this.clientRegistrationRepository.findByRegistrationId(any())) + .thenReturn(Mono.just(clientRegistration)); + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) + .thenReturn(Mono.empty()); + when(this.authenticationManager.authenticate(any())) + .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); + + MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete(), Collections.emptyList()); + + ServerRequestCache requestCache = mock(ServerRequestCache.class); + when(requestCache.getRedirectUri(any(ServerWebExchange.class))).thenReturn(Mono.just(URI.create("/saved-request"))); + + this.filter.setRequestCache(requestCache); + + this.filter.filter(exchange, chain).block(); + + verify(requestCache).getRedirectUri(exchange); + assertThat(exchange.getResponse().getHeaders().getLocation().toString()).isEqualTo("/saved-request"); + } + private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest( MockServerHttpRequest authorizationRequest, ClientRegistration registration) { Map attributes = new HashMap<>();