mirror of
				https://github.com/spring-projects/spring-security.git
				synced 2025-10-30 22:28:46 +00:00 
			
		
		
		
	Add ServerRequestCache setter in OAuth2AuthorizationCodeGrantWebFilter
Fixes gh-8536
This commit is contained in:
		
							parent
							
								
									aa84c79e87
								
							
						
					
					
						commit
						28d2cfa14a
					
				| @ -236,6 +236,7 @@ import static org.springframework.security.web.server.util.matcher.ServerWebExch | ||||
|  * @author Rafiullah Hamedy | ||||
|  * @author Eddú Meléndez | ||||
|  * @author Joe Grandja | ||||
|  * @author Parikshit Dutta | ||||
|  * @since 5.0 | ||||
|  */ | ||||
| public class ServerHttpSecurity { | ||||
| @ -1511,10 +1512,17 @@ public class ServerHttpSecurity { | ||||
| 			OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter( | ||||
| 					authenticationManager, authenticationConverter, authorizedClientRepository); | ||||
| 			codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); | ||||
| 			if (http.requestCache != null) { | ||||
| 				codeGrantWebFilter.setRequestCache(http.requestCache.requestCache); | ||||
| 			} | ||||
| 
 | ||||
| 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( | ||||
| 					clientRegistrationRepository); | ||||
| 			oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); | ||||
| 			if (http.requestCache != null) { | ||||
| 				oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); | ||||
| 			} | ||||
| 
 | ||||
| 			http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); | ||||
| 			http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); | ||||
| 		} | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| /* | ||||
|  * Copyright 2002-2019 the original author or authors. | ||||
|  * Copyright 2002-2020 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
| @ -16,6 +16,8 @@ | ||||
| 
 | ||||
| package org.springframework.security.config.web.server; | ||||
| 
 | ||||
| import java.net.URI; | ||||
| 
 | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.runner.RunWith; | ||||
| @ -48,6 +50,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio | ||||
| import org.springframework.security.test.context.support.WithMockUser; | ||||
| import org.springframework.security.web.server.SecurityWebFilterChain; | ||||
| import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; | ||||
| import org.springframework.security.web.server.savedrequest.ServerRequestCache; | ||||
| import org.springframework.test.context.junit4.SpringRunner; | ||||
| import org.springframework.test.web.reactive.server.WebTestClient; | ||||
| import org.springframework.web.bind.annotation.GetMapping; | ||||
| @ -62,6 +65,7 @@ import static org.mockito.Mockito.when; | ||||
| 
 | ||||
| /** | ||||
|  * @author Rob Winch | ||||
|  * @author Parikshit Dutta | ||||
|  * @since 5.1 | ||||
|  */ | ||||
| @RunWith(SpringRunner.class) | ||||
| @ -146,6 +150,7 @@ public class OAuth2ClientSpecTests { | ||||
| 		ServerAuthenticationConverter converter = config.authenticationConverter; | ||||
| 		ReactiveAuthenticationManager manager = config.manager; | ||||
| 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository; | ||||
| 		ServerRequestCache requestCache = config.requestCache; | ||||
| 
 | ||||
| 		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() | ||||
| 				.redirectUri("/authorize/oauth2/code/registration-id") | ||||
| @ -163,6 +168,7 @@ public class OAuth2ClientSpecTests { | ||||
| 		when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); | ||||
| 		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); | ||||
| 		when(manager.authenticate(any())).thenReturn(Mono.just(result)); | ||||
| 		when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request"))); | ||||
| 
 | ||||
| 		this.client.get() | ||||
| 				.uri(uriBuilder -> | ||||
| @ -175,6 +181,7 @@ public class OAuth2ClientSpecTests { | ||||
| 
 | ||||
| 		verify(converter).convert(any()); | ||||
| 		verify(manager).authenticate(any()); | ||||
| 		verify(requestCache).getRedirectUri(any()); | ||||
| 	} | ||||
| 
 | ||||
| 	@EnableWebFlux | ||||
| @ -197,13 +204,17 @@ public class OAuth2ClientSpecTests { | ||||
| 
 | ||||
| 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); | ||||
| 
 | ||||
| 		ServerRequestCache requestCache = mock(ServerRequestCache.class); | ||||
| 
 | ||||
| 		@Bean | ||||
| 		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { | ||||
| 			http | ||||
| 				.oauth2Client() | ||||
| 					.authenticationConverter(this.authenticationConverter) | ||||
| 					.authenticationManager(this.manager) | ||||
| 					.authorizationRequestRepository(this.authorizationRequestRepository); | ||||
| 					.authorizationRequestRepository(this.authorizationRequestRepository) | ||||
| 					.and() | ||||
| 				.requestCache(c -> c.requestCache(this.requestCache)); | ||||
| 			return http.build(); | ||||
| 		} | ||||
| 	} | ||||
| @ -217,6 +228,7 @@ public class OAuth2ClientSpecTests { | ||||
| 		ServerAuthenticationConverter converter = config.authenticationConverter; | ||||
| 		ReactiveAuthenticationManager manager = config.manager; | ||||
| 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository; | ||||
| 		ServerRequestCache requestCache = config.requestCache; | ||||
| 
 | ||||
| 		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() | ||||
| 				.redirectUri("/authorize/oauth2/code/registration-id") | ||||
| @ -234,6 +246,7 @@ public class OAuth2ClientSpecTests { | ||||
| 		when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); | ||||
| 		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c"))); | ||||
| 		when(manager.authenticate(any())).thenReturn(Mono.just(result)); | ||||
| 		when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request"))); | ||||
| 
 | ||||
| 		this.client.get() | ||||
| 				.uri(uriBuilder -> | ||||
| @ -246,6 +259,7 @@ public class OAuth2ClientSpecTests { | ||||
| 
 | ||||
| 		verify(converter).convert(any()); | ||||
| 		verify(manager).authenticate(any()); | ||||
| 		verify(requestCache).getRedirectUri(any()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Configuration | ||||
| @ -256,6 +270,8 @@ public class OAuth2ClientSpecTests { | ||||
| 
 | ||||
| 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); | ||||
| 
 | ||||
| 		ServerRequestCache requestCache = mock(ServerRequestCache.class); | ||||
| 
 | ||||
| 		@Bean | ||||
| 		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { | ||||
| 			http | ||||
| @ -263,8 +279,8 @@ public class OAuth2ClientSpecTests { | ||||
| 					oauth2Client | ||||
| 						.authenticationConverter(this.authenticationConverter) | ||||
| 						.authenticationManager(this.manager) | ||||
| 						.authorizationRequestRepository(this.authorizationRequestRepository) | ||||
| 				); | ||||
| 						.authorizationRequestRepository(this.authorizationRequestRepository)) | ||||
| 				.requestCache(c -> c.requestCache(this.requestCache)); | ||||
| 			return http.build(); | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -35,6 +35,8 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth | ||||
| import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; | ||||
| import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; | ||||
| import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; | ||||
| import org.springframework.security.web.server.savedrequest.ServerRequestCache; | ||||
| import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; | ||||
| import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; | ||||
| import org.springframework.util.Assert; | ||||
| import org.springframework.web.server.ServerWebExchange; | ||||
| @ -80,6 +82,7 @@ import java.util.Set; | ||||
|  * | ||||
|  * @author Rob Winch | ||||
|  * @author Joe Grandja | ||||
|  * @author Parikshit Dutta | ||||
|  * @since 5.1 | ||||
|  * @see OAuth2AuthorizationCodeAuthenticationToken | ||||
|  * @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager | ||||
| @ -111,6 +114,8 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { | ||||
| 
 | ||||
| 	private ServerWebExchangeMatcher requiresAuthenticationMatcher; | ||||
| 
 | ||||
| 	private ServerRequestCache requestCache = new WebSessionServerRequestCache(); | ||||
| 
 | ||||
| 	private AnonymousAuthenticationToken anonymousToken = new AnonymousAuthenticationToken("key", "anonymous", | ||||
| 					AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); | ||||
| 
 | ||||
| @ -129,7 +134,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { | ||||
| 		authenticationConverter.setAuthorizationRequestRepository(this.authorizationRequestRepository); | ||||
| 		this.authenticationConverter = authenticationConverter; | ||||
| 		this.defaultAuthenticationConverter = true; | ||||
| 		this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); | ||||
| 		RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = | ||||
| 				new RedirectServerAuthenticationSuccessHandler(); | ||||
| 		authenticationSuccessHandler.setRequestCache(this.requestCache); | ||||
| 		this.authenticationSuccessHandler = authenticationSuccessHandler; | ||||
| 		this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); | ||||
| 	} | ||||
| 
 | ||||
| @ -144,7 +152,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { | ||||
| 		this.authorizedClientRepository = authorizedClientRepository; | ||||
| 		this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse; | ||||
| 		this.authenticationConverter = authenticationConverter; | ||||
| 		this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler(); | ||||
| 		RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler = | ||||
| 				new RedirectServerAuthenticationSuccessHandler(); | ||||
| 		authenticationSuccessHandler.setRequestCache(this.requestCache); | ||||
| 		this.authenticationSuccessHandler = authenticationSuccessHandler; | ||||
| 		this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception); | ||||
| 	} | ||||
| 
 | ||||
| @ -169,6 +180,23 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	/** | ||||
| 	 * Sets the {@link ServerRequestCache} used for loading a previously saved request (if available) | ||||
| 	 * and replaying it after completing the processing of the OAuth 2.0 Authorization Response. | ||||
| 	 * | ||||
| 	 * @since 5.4 | ||||
| 	 * @param requestCache the cache used for loading a previously saved request (if available) | ||||
| 	 */ | ||||
| 	public final void setRequestCache(ServerRequestCache requestCache) { | ||||
| 		Assert.notNull(requestCache, "requestCache cannot be null"); | ||||
| 		this.requestCache = requestCache; | ||||
| 		updateDefaultAuthenticationSuccessHandler(); | ||||
| 	} | ||||
| 
 | ||||
| 	private void updateDefaultAuthenticationSuccessHandler() { | ||||
| 		((RedirectServerAuthenticationSuccessHandler) this.authenticationSuccessHandler).setRequestCache(this.requestCache); | ||||
| 	} | ||||
| 
 | ||||
| 	@Override | ||||
| 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { | ||||
| 		return this.requiresAuthenticationMatcher.matches(exchange) | ||||
|  | ||||
| @ -31,17 +31,22 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg | ||||
| import org.springframework.security.oauth2.client.registration.TestClientRegistrations; | ||||
| import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; | ||||
| import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; | ||||
| import org.springframework.security.web.server.savedrequest.ServerRequestCache; | ||||
| import org.springframework.util.CollectionUtils; | ||||
| import org.springframework.web.server.ServerWebExchange; | ||||
| import org.springframework.web.server.handler.DefaultWebFilterChain; | ||||
| import reactor.core.publisher.Mono; | ||||
| 
 | ||||
| import java.net.URI; | ||||
| import java.util.Collections; | ||||
| import java.util.HashMap; | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| import static org.assertj.core.api.Assertions.assertThat; | ||||
| import static org.assertj.core.api.Assertions.assertThatCode; | ||||
| import static org.mockito.ArgumentMatchers.any; | ||||
| import static org.mockito.Mockito.mock; | ||||
| import static org.mockito.Mockito.times; | ||||
| import static org.mockito.Mockito.verify; | ||||
| import static org.mockito.Mockito.verifyNoInteractions; | ||||
| @ -50,6 +55,7 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author | ||||
| 
 | ||||
| /** | ||||
|  * @author Rob Winch | ||||
|  * @author Parikshit Dutta | ||||
|  * @since 5.1 | ||||
|  */ | ||||
| @RunWith(MockitoJUnitRunner.class) | ||||
| @ -99,6 +105,12 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { | ||||
| 				.isInstanceOf(IllegalArgumentException.class); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { | ||||
| 		assertThatCode(() -> this.filter.setRequestCache(null)) | ||||
| 				.isInstanceOf(IllegalArgumentException.class); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void filterWhenNotMatchThenAuthenticationManagerNotCalled() { | ||||
| 		MockServerWebExchange exchange = MockServerWebExchange | ||||
| @ -233,6 +245,40 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { | ||||
| 		verifyNoInteractions(this.authenticationManager); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void filterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestCacheUsed() { | ||||
| 		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); | ||||
| 		when(this.clientRegistrationRepository.findByRegistrationId(any())) | ||||
| 				.thenReturn(Mono.just(clientRegistration)); | ||||
| 		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) | ||||
| 				.thenReturn(Mono.empty()); | ||||
| 		when(this.authenticationManager.authenticate(any())) | ||||
| 				.thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); | ||||
| 
 | ||||
| 		MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback"); | ||||
| 		OAuth2AuthorizationRequest oauth2AuthorizationRequest = | ||||
| 				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); | ||||
| 		when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) | ||||
| 				.thenReturn(Mono.just(oauth2AuthorizationRequest)); | ||||
| 		when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) | ||||
| 				.thenReturn(Mono.just(oauth2AuthorizationRequest)); | ||||
| 
 | ||||
| 		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); | ||||
| 		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); | ||||
| 		DefaultWebFilterChain chain = new DefaultWebFilterChain( | ||||
| 				e -> e.getResponse().setComplete(), Collections.emptyList()); | ||||
| 
 | ||||
| 		ServerRequestCache requestCache = mock(ServerRequestCache.class); | ||||
| 		when(requestCache.getRedirectUri(any(ServerWebExchange.class))).thenReturn(Mono.just(URI.create("/saved-request"))); | ||||
| 
 | ||||
| 		this.filter.setRequestCache(requestCache); | ||||
| 
 | ||||
| 		this.filter.filter(exchange, chain).block(); | ||||
| 
 | ||||
| 		verify(requestCache).getRedirectUri(exchange); | ||||
| 		assertThat(exchange.getResponse().getHeaders().getLocation().toString()).isEqualTo("/saved-request"); | ||||
| 	} | ||||
| 
 | ||||
| 	private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest( | ||||
| 			MockServerHttpRequest authorizationRequest, ClientRegistration registration) { | ||||
| 		Map<String, Object> attributes = new HashMap<>(); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user