Add SecurityContextHolderStrategy to OAuth2

Issue gh-11060
This commit is contained in:
Josh Cummings 2022-06-21 16:33:43 -06:00
parent 6c16ac101a
commit 1d72a05c32
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
9 changed files with 138 additions and 16 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
@ -103,6 +104,9 @@ import org.springframework.web.util.UriComponentsBuilder;
*/
public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
@ -158,6 +162,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
this.requestCache = requestCache;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
@ -232,7 +247,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString());
return;
}
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
Authentication currentAuthentication = this.securityContextHolderStrategy.getContext().getAuthentication();
String principalName = (currentAuthentication != null) ? currentAuthentication.getName() : "anonymousUser";
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(), principalName, authenticationResult.getAccessToken(),

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
@ -72,6 +73,9 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private OAuth2AuthorizedClientManager authorizedClientManager;
private boolean defaultAuthorizedClientManager;
@ -120,7 +124,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
+ "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or "
+ "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
}
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
if (principal == null) {
principal = ANONYMOUS_AUTHENTICATION;
}
@ -140,7 +144,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private String resolveClientRegistrationId(MethodParameter parameter) {
RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
Authentication principal = SecurityContextHolder.getContext().getAuthentication();
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
if (!StringUtils.isEmpty(authorizedClientAnnotation.registrationId())) {
return authorizedClientAnnotation.registrationId();
}
@ -179,6 +183,17 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
updateDefaultAuthorizedClientManager(clientCredentialsTokenResponseClient);
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
private void updateDefaultAuthorizedClientManager(
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
// @formatter:off

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
@ -151,6 +152,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
@Deprecated
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
@ -304,6 +308,17 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
this.defaultClientRegistrationId = clientRegistrationId;
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/**
* Configures the builder with {@link #defaultRequest()} and adds this as a
* {@link ExchangeFilterFunction}
@ -513,7 +528,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
return;
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,6 +39,8 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
@ -306,6 +308,23 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}
@Test
public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
this.setUpAuthenticationResult(this.registration1);
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext())
.willReturn(new SecurityContextImpl(new TestingAuthenticationToken("user", "password")));
this.filter.setSecurityContextHolderStrategy(strategy);
this.filter.doFilter(authorizationResponse, response, filterChain);
verify(strategy).getContext();
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
}
@Test
public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
String requestUri = "/saved-request";

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,6 +34,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@ -70,6 +72,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -254,6 +257,18 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1);
}
@Test
public void resolveArgumentWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication));
this.argumentResolver.setSecurityContextHolderStrategy(strategy);
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient",
OAuth2AuthorizedClient.class);
assertThat(this.argumentResolver.resolveArgument(methodParameter, null,
new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1);
verify(strategy, atLeastOnce()).getContext();
}
@Test
public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() {
MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid",

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -65,6 +65,8 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@ -282,6 +284,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
verifyNoInteractions(this.authorizedClientRepository);
}
@Test
public void defaultRequestAuthenticationWhenCustomSecurityContextHolderStrategyThenAuthenticationSet() {
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication));
this.function.setSecurityContextHolderStrategy(strategy);
Map<String, Object> attrs = getDefaultRequestAttributes();
assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs))
.isEqualTo(this.authentication);
verify(strategy).getContext();
verifyNoInteractions(this.authorizedClientRepository);
}
private Map<String, Object> getDefaultRequestAttributes() {
this.function.defaultRequest().accept(this.spec);
verify(this.spec).attributes(this.attrs.capture());

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider;
@ -64,6 +65,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
private final AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver;
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint();
private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) -> {
@ -132,9 +136,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
try {
AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request);
Authentication authenticationResult = authenticationManager.authenticate(authenticationRequest);
SecurityContext context = SecurityContextHolder.createEmptyContext();
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authenticationResult);
SecurityContextHolder.setContext(context);
this.securityContextHolderStrategy.setContext(context);
this.securityContextRepository.saveContext(context, request, response);
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult));
@ -142,12 +146,23 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
filterChain.doFilter(request, response);
}
catch (AuthenticationException failed) {
SecurityContextHolder.clearContext();
this.securityContextHolderStrategy.clearContext();
this.logger.trace("Failed to process authentication request", failed);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
}
}
/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
*
* @since 5.8
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
/**
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
* authentication success. The default action is not to save the

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,8 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.BearerTokenError;
@ -205,6 +207,18 @@ public class BearerTokenAuthenticationFilterTests {
verify(this.authenticationDetailsSource).buildDetails(this.request);
}
@Test
public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws ServletException, IOException {
given(this.bearerTokenResolver.resolve(this.request)).willReturn("token");
BearerTokenAuthenticationFilter filter = addMocks(
new BearerTokenAuthenticationFilter(this.authenticationManager));
SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl());
filter.setSecurityContextHolderStrategy(strategy);
filter.doFilter(this.request, this.response, this.filterChain);
verify(strategy).setContext(any());
}
@Test
public void setAuthenticationEntryPointWhenNullThenThrowsException() {
BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);