diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptor.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptor.java index 2a6128d971..1fe4b704ce 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptor.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptor.java @@ -34,8 +34,6 @@ import org.springframework.lang.Nullable; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; @@ -121,16 +119,15 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque private final OAuth2AuthorizedClientManager authorizedClientManager; - private final ClientRegistrationIdResolver clientRegistrationIdResolver; + private ClientRegistrationIdResolver clientRegistrationIdResolver = new RequestAttributeClientRegistrationIdResolver(); + + private PrincipalResolver principalResolver = new SecurityContextHolderPrincipalResolver(); // @formatter:off private OAuth2AuthorizationFailureHandler authorizationFailureHandler = (clientRegistrationId, principal, attributes) -> { }; // @formatter:on - private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder - .getContextHolderStrategy(); - /** * Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided * parameters. @@ -138,23 +135,8 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque * manages the authorized client(s) */ public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) { - this(authorizedClientManager, new RequestAttributeClientRegistrationIdResolver()); - } - - /** - * Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided - * parameters. - * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which - * manages the authorized client(s) - * @param clientRegistrationIdResolver the strategy for resolving a - * {@code clientRegistrationId} from the intercepted request - */ - public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager, - ClientRegistrationIdResolver clientRegistrationIdResolver) { Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); - Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null"); this.authorizedClientManager = authorizedClientManager; - this.clientRegistrationIdResolver = clientRegistrationIdResolver; } /** @@ -238,20 +220,31 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque } /** - * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use - * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. - * @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to - * use + * Sets the strategy for resolving a {@code clientRegistrationId} from an intercepted + * request. + * @param clientRegistrationIdResolver the strategy for resolving a + * {@code clientRegistrationId} from an intercepted request */ - public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { - Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); - this.securityContextHolderStrategy = securityContextHolderStrategy; + public void setClientRegistrationIdResolver(ClientRegistrationIdResolver clientRegistrationIdResolver) { + Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null"); + this.clientRegistrationIdResolver = clientRegistrationIdResolver; + } + + /** + * Sets the strategy for resolving a {@link Authentication principal} from an + * intercepted request. + * @param principalResolver the strategy for resolving a {@link Authentication + * principal} + */ + public void setPrincipalResolver(PrincipalResolver principalResolver) { + Assert.notNull(principalResolver, "principalResolver cannot be null"); + this.principalResolver = principalResolver; } @Override public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { - Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication(); + Authentication principal = this.principalResolver.resolve(request); if (principal == null) { principal = ANONYMOUS_AUTHENTICATION; } @@ -378,4 +371,24 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque } + /** + * A strategy for resolving a {@link Authentication principal} from an intercepted + * request. + */ + @FunctionalInterface + public interface PrincipalResolver { + + /** + * Resolve the {@link Authentication principal} from the current request, which is + * used to obtain an {@link OAuth2AuthorizedClient}. + * @param request the intercepted request, containing HTTP method, URI, headers, + * and request attributes + * @return the {@link Authentication principal} to be used for resolving an + * {@link OAuth2AuthorizedClient}. + */ + @Nullable + Authentication resolve(HttpRequest request); + + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/RequestAttributePrincipalResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/RequestAttributePrincipalResolver.java new file mode 100644 index 0000000000..bbae6e86c0 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/RequestAttributePrincipalResolver.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.client; + +import java.util.Collections; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.http.HttpRequest; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.util.Assert; + +/** + * A strategy for resolving a {@link Authentication principal} from an intercepted request + * using {@link ClientHttpRequest#getAttributes() attributes}. + * + * @author Steve Riesenberg + * @since 6.4 + */ +public class RequestAttributePrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver { + + private static final String PRINCIPAL_ATTR_NAME = RequestAttributePrincipalResolver.class.getName() + .concat(".principal"); + + @Override + public Authentication resolve(HttpRequest request) { + return (Authentication) request.getAttributes().get(PRINCIPAL_ATTR_NAME); + } + + /** + * Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the + * {@link Authentication principal} to be used to look up the + * {@link OAuth2AuthorizedClient}. + * @param principal the {@link Authentication principal} to be used to look up the + * {@link OAuth2AuthorizedClient} + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> principal(Authentication principal) { + Assert.notNull(principal, "principal cannot be null"); + return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal); + } + + /** + * Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the + * {@link Authentication principal} to be used to look up the + * {@link OAuth2AuthorizedClient}. + * @param principalName the {@code principalName} to be used to look up the + * {@link OAuth2AuthorizedClient} + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> principal(String principalName) { + Assert.hasText(principalName, "principalName cannot be empty"); + Authentication principal = createAuthentication(principalName); + return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal); + } + + private static Authentication createAuthentication(String principalName) { + return new AbstractAuthenticationToken(Collections.emptySet()) { + @Override + public Object getPrincipal() { + return principalName; + } + + @Override + public Object getCredentials() { + return null; + } + }; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/SecurityContextHolderPrincipalResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/SecurityContextHolderPrincipalResolver.java new file mode 100644 index 0000000000..8ffd5b204c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/SecurityContextHolderPrincipalResolver.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.client; + +import org.springframework.http.HttpRequest; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; + +/** + * A strategy for resolving a {@link Authentication principal} from an intercepted request + * using the {@link SecurityContextHolder}. + * + * @author Steve Riesenberg + * @since 6.4 + */ +public class SecurityContextHolderPrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver { + + private final SecurityContextHolderStrategy securityContextHolderStrategy; + + /** + * Constructs a {@code SecurityContextHolderPrincipalResolver}. + */ + public SecurityContextHolderPrincipalResolver() { + this(SecurityContextHolder.getContextHolderStrategy()); + } + + /** + * Constructs a {@code SecurityContextHolderPrincipalResolver} using the provided + * parameters. + * @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to + * use for resolving the {@link Authentication principal} + */ + public SecurityContextHolderPrincipalResolver(SecurityContextHolderStrategy securityContextHolderStrategy) { + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + + @Override + public Authentication resolve(HttpRequest request) { + return this.securityContextHolderStrategy.getContext().getAuthentication(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptorTests.java index f82f79034c..f19d81c277 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptorTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/client/OAuth2ClientHttpRequestInterceptorTests.java @@ -43,7 +43,6 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; @@ -110,15 +109,15 @@ public class OAuth2ClientHttpRequestInterceptorTests { @Mock private OAuth2AuthorizedClientRepository authorizedClientRepository; - @Mock - private SecurityContextHolderStrategy securityContextHolderStrategy; - @Mock private OAuth2AuthorizedClientService authorizedClientService; @Mock private OAuth2ClientHttpRequestInterceptor.ClientRegistrationIdResolver clientRegistrationIdResolver; + @Mock + private OAuth2ClientHttpRequestInterceptor.PrincipalResolver principalResolver; + @Captor private ArgumentCaptor authorizeRequestCaptor; @@ -167,13 +166,6 @@ public class OAuth2ClientHttpRequestInterceptorTests { .withMessage("authorizedClientManager cannot be null"); } - @Test - public void constructorWhenClientRegistrationIdResolverIsNullThenThrowsIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager, null)) - .withMessage("clientRegistrationIdResolver cannot be null"); - } - @Test public void setAuthorizationFailureHandlerWhenNullThenThrowsIllegalArgumentException() { assertThatIllegalArgumentException() @@ -198,10 +190,16 @@ public class OAuth2ClientHttpRequestInterceptorTests { } @Test - public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() { + public void setClientRegistrationIdResolverWhenNullThenThrowsIllegalArgumentException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.requestInterceptor.setSecurityContextHolderStrategy(null)) - .withMessage("securityContextHolderStrategy cannot be null"); + .isThrownBy(() -> this.requestInterceptor.setClientRegistrationIdResolver(null)) + .withMessage("clientRegistrationIdResolver cannot be null"); + } + + @Test + public void setPrincipalResolverWhenNullThenThrowsIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.requestInterceptor.setPrincipalResolver(null)) + .withMessage("principalResolver cannot be null"); } @Test @@ -605,8 +603,7 @@ public class OAuth2ClientHttpRequestInterceptorTests { @Test public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() { - this.requestInterceptor = new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager, - this.clientRegistrationIdResolver); + this.requestInterceptor.setClientRegistrationIdResolver(this.clientRegistrationIdResolver); this.requestInterceptor.setAuthorizationFailureHandler(this.authorizationFailureHandler); given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))) .willReturn(this.authorizedClient); @@ -625,7 +622,7 @@ public class OAuth2ClientHttpRequestInterceptorTests { this.server.verify(); verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture()); verify(this.clientRegistrationIdResolver).resolve(any(HttpRequest.class)); - verifyNoMoreInteractions(this.clientRegistrationIdResolver, this.authorizedClientManager); + verifyNoMoreInteractions(this.authorizedClientManager, this.clientRegistrationIdResolver); verifyNoInteractions(this.authorizationFailureHandler); OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue(); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistrationId); @@ -633,8 +630,8 @@ public class OAuth2ClientHttpRequestInterceptorTests { } @Test - public void interceptWhenCustomSecurityContextHolderStrategySetThenUsed() { - this.requestInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); + public void interceptWhenCustomPrincipalResolverSetThenUsed() { + this.requestInterceptor.setPrincipalResolver(this.principalResolver); given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))) .willReturn(this.authorizedClient); @@ -642,14 +639,12 @@ public class OAuth2ClientHttpRequestInterceptorTests { this.server.expect(requestTo(REQUEST_URI)) .andExpect(hasAuthorizationHeader(this.authorizedClient.getAccessToken())) .andRespond(withApplicationJson()); - SecurityContext securityContext = new SecurityContextImpl(); - securityContext.setAuthentication(this.principal); - given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext); + given(this.principalResolver.resolve(any(HttpRequest.class))).willReturn(this.principal); performRequest(withClientRegistrationId()); this.server.verify(); verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture()); - verify(this.securityContextHolderStrategy).getContext(); - verifyNoMoreInteractions(this.authorizedClientManager, this.securityContextHolderStrategy); + verify(this.principalResolver).resolve(any(HttpRequest.class)); + verifyNoMoreInteractions(this.authorizedClientManager, this.principalResolver); OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue(); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);