mirror of
				https://github.com/spring-projects/spring-security.git
				synced 2025-10-30 22:28:46 +00:00 
			
		
		
		
	Use SecurityContextHolderStrategy for Messaging
Issue gh-11060
This commit is contained in:
		
							parent
							
								
									275586be5f
								
							
						
					
					
						commit
						1e498df39b
					
				| @ -32,6 +32,7 @@ import org.springframework.security.authorization.AuthorizationEventPublisher; | |||||||
| import org.springframework.security.authorization.AuthorizationManager; | import org.springframework.security.authorization.AuthorizationManager; | ||||||
| import org.springframework.security.core.Authentication; | import org.springframework.security.core.Authentication; | ||||||
| import org.springframework.security.core.context.SecurityContextHolder; | import org.springframework.security.core.context.SecurityContextHolder; | ||||||
|  | import org.springframework.security.core.context.SecurityContextHolderStrategy; | ||||||
| import org.springframework.util.Assert; | import org.springframework.util.Assert; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
| @ -42,14 +43,8 @@ import org.springframework.util.Assert; | |||||||
|  */ |  */ | ||||||
| public final class AuthorizationChannelInterceptor implements ChannelInterceptor { | public final class AuthorizationChannelInterceptor implements ChannelInterceptor { | ||||||
| 
 | 
 | ||||||
| 	static final Supplier<Authentication> AUTHENTICATION_SUPPLIER = () -> { | 	private Supplier<Authentication> authentication = getAuthentication( | ||||||
| 		Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); | 			SecurityContextHolder.getContextHolderStrategy()); | ||||||
| 		if (authentication == null) { |  | ||||||
| 			throw new AuthenticationCredentialsNotFoundException( |  | ||||||
| 					"An Authentication object was not found in the SecurityContext"); |  | ||||||
| 		} |  | ||||||
| 		return authentication; |  | ||||||
| 	}; |  | ||||||
| 
 | 
 | ||||||
| 	private final Log logger = LogFactory.getLog(this.getClass()); | 	private final Log logger = LogFactory.getLog(this.getClass()); | ||||||
| 
 | 
 | ||||||
| @ -71,8 +66,8 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor | |||||||
| 	@Override | 	@Override | ||||||
| 	public Message<?> preSend(Message<?> message, MessageChannel channel) { | 	public Message<?> preSend(Message<?> message, MessageChannel channel) { | ||||||
| 		this.logger.debug(LogMessage.of(() -> "Authorizing message send")); | 		this.logger.debug(LogMessage.of(() -> "Authorizing message send")); | ||||||
| 		AuthorizationDecision decision = this.preSendAuthorizationManager.check(AUTHENTICATION_SUPPLIER, message); | 		AuthorizationDecision decision = this.preSendAuthorizationManager.check(this.authentication, message); | ||||||
| 		this.eventPublisher.publishAuthorizationEvent(AUTHENTICATION_SUPPLIER, message, decision); | 		this.eventPublisher.publishAuthorizationEvent(this.authentication, message, decision); | ||||||
| 		if (decision == null || !decision.isGranted()) { // default deny | 		if (decision == null || !decision.isGranted()) { // default deny | ||||||
| 			this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager " | 			this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager " | ||||||
| 					+ this.preSendAuthorizationManager + " and decision " + decision)); | 					+ this.preSendAuthorizationManager + " and decision " + decision)); | ||||||
| @ -82,6 +77,14 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor | |||||||
| 		return message; | 		return message; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use | ||||||
|  | 	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. | ||||||
|  | 	 */ | ||||||
|  | 	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { | ||||||
|  | 		this.authentication = getAuthentication(securityContextHolderStrategy); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	/** | 	/** | ||||||
| 	 * Use this {@link AuthorizationEventPublisher} to publish the | 	 * Use this {@link AuthorizationEventPublisher} to publish the | ||||||
| 	 * {@link AuthorizationManager} result. | 	 * {@link AuthorizationManager} result. | ||||||
| @ -92,6 +95,17 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor | |||||||
| 		this.eventPublisher = eventPublisher; | 		this.eventPublisher = eventPublisher; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	private Supplier<Authentication> getAuthentication(SecurityContextHolderStrategy strategy) { | ||||||
|  | 		return () -> { | ||||||
|  | 			Authentication authentication = strategy.getContext().getAuthentication(); | ||||||
|  | 			if (authentication == null) { | ||||||
|  | 				throw new AuthenticationCredentialsNotFoundException( | ||||||
|  | 						"An Authentication object was not found in the SecurityContext"); | ||||||
|  | 			} | ||||||
|  | 			return authentication; | ||||||
|  | 		}; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher { | 	private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher { | ||||||
| 
 | 
 | ||||||
| 		@Override | 		@Override | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| /* | /* | ||||||
|  * Copyright 2002-2021 the original author or authors. |  * Copyright 2002-2022 the original author or authors. | ||||||
|  * |  * | ||||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  * you may not use this file except in compliance with the License. |  * you may not use this file except in compliance with the License. | ||||||
| @ -29,7 +29,9 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes | |||||||
| import org.springframework.security.core.Authentication; | import org.springframework.security.core.Authentication; | ||||||
| import org.springframework.security.core.annotation.AuthenticationPrincipal; | import org.springframework.security.core.annotation.AuthenticationPrincipal; | ||||||
| import org.springframework.security.core.context.SecurityContextHolder; | import org.springframework.security.core.context.SecurityContextHolder; | ||||||
|  | import org.springframework.security.core.context.SecurityContextHolderStrategy; | ||||||
| import org.springframework.stereotype.Controller; | import org.springframework.stereotype.Controller; | ||||||
|  | import org.springframework.util.Assert; | ||||||
| import org.springframework.util.ClassUtils; | import org.springframework.util.ClassUtils; | ||||||
| import org.springframework.util.StringUtils; | import org.springframework.util.StringUtils; | ||||||
| 
 | 
 | ||||||
| @ -85,6 +87,9 @@ import org.springframework.util.StringUtils; | |||||||
|  */ |  */ | ||||||
| public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { | public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver { | ||||||
| 
 | 
 | ||||||
|  | 	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder | ||||||
|  | 			.getContextHolderStrategy(); | ||||||
|  | 
 | ||||||
| 	private ExpressionParser parser = new SpelExpressionParser(); | 	private ExpressionParser parser = new SpelExpressionParser(); | ||||||
| 
 | 
 | ||||||
| 	@Override | 	@Override | ||||||
| @ -94,7 +99,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet | |||||||
| 
 | 
 | ||||||
| 	@Override | 	@Override | ||||||
| 	public Object resolveArgument(MethodParameter parameter, Message<?> message) { | 	public Object resolveArgument(MethodParameter parameter, Message<?> message) { | ||||||
| 		Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); | 		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); | ||||||
| 		if (authentication == null) { | 		if (authentication == null) { | ||||||
| 			return null; | 			return null; | ||||||
| 		} | 		} | ||||||
| @ -117,6 +122,17 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet | |||||||
| 		return principal; | 		return principal; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use | ||||||
|  | 	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. | ||||||
|  | 	 * | ||||||
|  | 	 * @since 5.8 | ||||||
|  | 	 */ | ||||||
|  | 	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { | ||||||
|  | 		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); | ||||||
|  | 		this.securityContextHolderStrategy = securityContextHolderStrategy; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	/** | 	/** | ||||||
| 	 * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. | 	 * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. | ||||||
| 	 * @param annotationClass the class of the {@link Annotation} to find on the | 	 * @param annotationClass the class of the {@link Annotation} to find on the | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| /* | /* | ||||||
|  * Copyright 2002-2016 the original author or authors. |  * Copyright 2002-2022 the original author or authors. | ||||||
|  * |  * | ||||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  * you may not use this file except in compliance with the License. |  * you may not use this file except in compliance with the License. | ||||||
| @ -29,6 +29,7 @@ import org.springframework.security.core.Authentication; | |||||||
| import org.springframework.security.core.authority.AuthorityUtils; | import org.springframework.security.core.authority.AuthorityUtils; | ||||||
| import org.springframework.security.core.context.SecurityContext; | import org.springframework.security.core.context.SecurityContext; | ||||||
| import org.springframework.security.core.context.SecurityContextHolder; | import org.springframework.security.core.context.SecurityContextHolder; | ||||||
|  | import org.springframework.security.core.context.SecurityContextHolderStrategy; | ||||||
| import org.springframework.util.Assert; | import org.springframework.util.Assert; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
| @ -42,10 +43,13 @@ import org.springframework.util.Assert; | |||||||
|  */ |  */ | ||||||
| public final class SecurityContextChannelInterceptor implements ExecutorChannelInterceptor, ChannelInterceptor { | public final class SecurityContextChannelInterceptor implements ExecutorChannelInterceptor, ChannelInterceptor { | ||||||
| 
 | 
 | ||||||
| 	private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); |  | ||||||
| 
 |  | ||||||
| 	private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>(); | 	private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>(); | ||||||
| 
 | 
 | ||||||
|  | 	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder | ||||||
|  | 			.getContextHolderStrategy(); | ||||||
|  | 
 | ||||||
|  | 	private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext(); | ||||||
|  | 
 | ||||||
| 	private final String authenticationHeaderName; | 	private final String authenticationHeaderName; | ||||||
| 
 | 
 | ||||||
| 	private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", | 	private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", | ||||||
| @ -107,8 +111,13 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI | |||||||
| 		cleanup(); | 		cleanup(); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) { | ||||||
|  | 		this.securityContextHolderStrategy = strategy; | ||||||
|  | 		this.empty = this.securityContextHolderStrategy.createEmptyContext(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	private void setup(Message<?> message) { | 	private void setup(Message<?> message) { | ||||||
| 		SecurityContext currentContext = SecurityContextHolder.getContext(); | 		SecurityContext currentContext = this.securityContextHolderStrategy.getContext(); | ||||||
| 		Stack<SecurityContext> contextStack = originalContext.get(); | 		Stack<SecurityContext> contextStack = originalContext.get(); | ||||||
| 		if (contextStack == null) { | 		if (contextStack == null) { | ||||||
| 			contextStack = new Stack<>(); | 			contextStack = new Stack<>(); | ||||||
| @ -117,9 +126,9 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI | |||||||
| 		contextStack.push(currentContext); | 		contextStack.push(currentContext); | ||||||
| 		Object user = message.getHeaders().get(this.authenticationHeaderName); | 		Object user = message.getHeaders().get(this.authenticationHeaderName); | ||||||
| 		Authentication authentication = getAuthentication(user); | 		Authentication authentication = getAuthentication(user); | ||||||
| 		SecurityContext context = SecurityContextHolder.createEmptyContext(); | 		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); | ||||||
| 		context.setAuthentication(authentication); | 		context.setAuthentication(authentication); | ||||||
| 		SecurityContextHolder.setContext(context); | 		this.securityContextHolderStrategy.setContext(context); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	private Authentication getAuthentication(Object user) { | 	private Authentication getAuthentication(Object user) { | ||||||
| @ -132,22 +141,22 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI | |||||||
| 	private void cleanup() { | 	private void cleanup() { | ||||||
| 		Stack<SecurityContext> contextStack = originalContext.get(); | 		Stack<SecurityContext> contextStack = originalContext.get(); | ||||||
| 		if (contextStack == null || contextStack.isEmpty()) { | 		if (contextStack == null || contextStack.isEmpty()) { | ||||||
| 			SecurityContextHolder.clearContext(); | 			this.securityContextHolderStrategy.clearContext(); | ||||||
| 			originalContext.remove(); | 			originalContext.remove(); | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
| 		SecurityContext context = contextStack.pop(); | 		SecurityContext context = contextStack.pop(); | ||||||
| 		try { | 		try { | ||||||
| 			if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) { | 			if (SecurityContextChannelInterceptor.this.empty.equals(context)) { | ||||||
| 				SecurityContextHolder.clearContext(); | 				this.securityContextHolderStrategy.clearContext(); | ||||||
| 				originalContext.remove(); | 				originalContext.remove(); | ||||||
| 			} | 			} | ||||||
| 			else { | 			else { | ||||||
| 				SecurityContextHolder.setContext(context); | 				this.securityContextHolderStrategy.setContext(context); | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		catch (Throwable ex) { | 		catch (Throwable ex) { | ||||||
| 			SecurityContextHolder.clearContext(); | 			this.securityContextHolderStrategy.clearContext(); | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| /* | /* | ||||||
|  * Copyright 2002-2016 the original author or authors. |  * Copyright 2002-2022 the original author or authors. | ||||||
|  * |  * | ||||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  * you may not use this file except in compliance with the License. |  * you may not use this file except in compliance with the License. | ||||||
| @ -34,9 +34,13 @@ import org.springframework.security.authentication.TestingAuthenticationToken; | |||||||
| import org.springframework.security.core.Authentication; | import org.springframework.security.core.Authentication; | ||||||
| import org.springframework.security.core.authority.AuthorityUtils; | import org.springframework.security.core.authority.AuthorityUtils; | ||||||
| import org.springframework.security.core.context.SecurityContextHolder; | import org.springframework.security.core.context.SecurityContextHolder; | ||||||
|  | import org.springframework.security.core.context.SecurityContextHolderStrategy; | ||||||
|  | import org.springframework.security.core.context.SecurityContextImpl; | ||||||
| 
 | 
 | ||||||
| import static org.assertj.core.api.Assertions.assertThat; | import static org.assertj.core.api.Assertions.assertThat; | ||||||
| import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; | import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; | ||||||
|  | import static org.mockito.Mockito.spy; | ||||||
|  | import static org.mockito.Mockito.verify; | ||||||
| 
 | 
 | ||||||
| @ExtendWith(MockitoExtension.class) | @ExtendWith(MockitoExtension.class) | ||||||
| public class SecurityContextChannelInterceptorTests { | public class SecurityContextChannelInterceptorTests { | ||||||
| @ -94,6 +98,17 @@ public class SecurityContextChannelInterceptorTests { | |||||||
| 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() { | ||||||
|  | 		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); | ||||||
|  | 		strategy.setContext(new SecurityContextImpl(this.authentication)); | ||||||
|  | 		this.interceptor.setSecurityContextHolderStrategy(strategy); | ||||||
|  | 		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||||||
|  | 		this.interceptor.preSend(this.messageBuilder.build(), this.channel); | ||||||
|  | 		verify(strategy).getContext(); | ||||||
|  | 		assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	@Test | 	@Test | ||||||
| 	public void setAnonymousAuthenticationNull() { | 	public void setAnonymousAuthenticationNull() { | ||||||
| 		assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null)); | 		assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null)); | ||||||
| @ -143,6 +158,16 @@ public class SecurityContextChannelInterceptorTests { | |||||||
| 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void afterSendCompletionWhenCustomSecurityContextHolderStrategyThenNullAuthentication() { | ||||||
|  | 		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); | ||||||
|  | 		strategy.setContext(new SecurityContextImpl(this.authentication)); | ||||||
|  | 		this.interceptor.setSecurityContextHolderStrategy(strategy); | ||||||
|  | 		this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null); | ||||||
|  | 		verify(strategy).clearContext(); | ||||||
|  | 		assertThat(strategy.getContext().getAuthentication()).isNull(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	@Test | 	@Test | ||||||
| 	public void beforeHandleUserSet() { | 	public void beforeHandleUserSet() { | ||||||
| 		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | 		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||||||
| @ -150,6 +175,17 @@ public class SecurityContextChannelInterceptorTests { | |||||||
| 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void beforeHandleWhenCustomSecurityContextHolderStrategyThenUserSet() { | ||||||
|  | 		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); | ||||||
|  | 		strategy.setContext(new SecurityContextImpl(this.authentication)); | ||||||
|  | 		this.interceptor.setSecurityContextHolderStrategy(strategy); | ||||||
|  | 		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||||||
|  | 		this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); | ||||||
|  | 		verify(strategy).getContext(); | ||||||
|  | 		assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// SEC-2845 | 	// SEC-2845 | ||||||
| 	@Test | 	@Test | ||||||
| 	public void beforeHandleUserNotAuthentication() { | 	public void beforeHandleUserNotAuthentication() { | ||||||
| @ -178,6 +214,15 @@ public class SecurityContextChannelInterceptorTests { | |||||||
| 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void afterMessageHandledWhenCustomSecurityContextHolderStrategyThenUses() { | ||||||
|  | 		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); | ||||||
|  | 		strategy.setContext(new SecurityContextImpl(this.authentication)); | ||||||
|  | 		this.interceptor.setSecurityContextHolderStrategy(strategy); | ||||||
|  | 		this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); | ||||||
|  | 		verify(strategy).clearContext(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// SEC-2829 | 	// SEC-2829 | ||||||
| 	@Test | 	@Test | ||||||
| 	public void restoresOriginalContext() { | 	public void restoresOriginalContext() { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user