Customize the strategy for resolving the principal

Closes gh-15826
This commit is contained in:
Steve Riesenberg 2024-09-19 10:10:54 -05:00
parent 50cb051c86
commit e11c188122
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
4 changed files with 206 additions and 53 deletions

View File

@ -34,8 +34,6 @@ import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
@ -121,16 +119,15 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
private final OAuth2AuthorizedClientManager authorizedClientManager; private final OAuth2AuthorizedClientManager authorizedClientManager;
private final ClientRegistrationIdResolver clientRegistrationIdResolver; private ClientRegistrationIdResolver clientRegistrationIdResolver = new RequestAttributeClientRegistrationIdResolver();
private PrincipalResolver principalResolver = new SecurityContextHolderPrincipalResolver();
// @formatter:off // @formatter:off
private OAuth2AuthorizationFailureHandler authorizationFailureHandler = private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
(clientRegistrationId, principal, attributes) -> { }; (clientRegistrationId, principal, attributes) -> { };
// @formatter:on // @formatter:on
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/** /**
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided * Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
* parameters. * parameters.
@ -138,23 +135,8 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
* manages the authorized client(s) * manages the authorized client(s)
*/ */
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) { public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) {
this(authorizedClientManager, new RequestAttributeClientRegistrationIdResolver());
}
/**
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
* parameters.
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
* manages the authorized client(s)
* @param clientRegistrationIdResolver the strategy for resolving a
* {@code clientRegistrationId} from the intercepted request
*/
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager,
ClientRegistrationIdResolver clientRegistrationIdResolver) {
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
this.authorizedClientManager = authorizedClientManager; this.authorizedClientManager = authorizedClientManager;
this.clientRegistrationIdResolver = clientRegistrationIdResolver;
} }
/** /**
@ -238,20 +220,31 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
} }
/** /**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use * Sets the strategy for resolving a {@code clientRegistrationId} from an intercepted
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. * request.
* @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to * @param clientRegistrationIdResolver the strategy for resolving a
* use * {@code clientRegistrationId} from an intercepted request
*/ */
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { public void setClientRegistrationIdResolver(ClientRegistrationIdResolver clientRegistrationIdResolver) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy; this.clientRegistrationIdResolver = clientRegistrationIdResolver;
}
/**
* Sets the strategy for resolving a {@link Authentication principal} from an
* intercepted request.
* @param principalResolver the strategy for resolving a {@link Authentication
* principal}
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
} }
@Override @Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
throws IOException { throws IOException {
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication(); Authentication principal = this.principalResolver.resolve(request);
if (principal == null) { if (principal == null) {
principal = ANONYMOUS_AUTHENTICATION; principal = ANONYMOUS_AUTHENTICATION;
} }
@ -378,4 +371,24 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
} }
/**
* A strategy for resolving a {@link Authentication principal} from an intercepted
* request.
*/
@FunctionalInterface
public interface PrincipalResolver {
/**
* Resolve the {@link Authentication principal} from the current request, which is
* used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Authentication principal} to be used for resolving an
* {@link OAuth2AuthorizedClient}.
*/
@Nullable
Authentication resolve(HttpRequest request);
}
} }

View File

@ -0,0 +1,88 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.client;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.util.Assert;
/**
* A strategy for resolving a {@link Authentication principal} from an intercepted request
* using {@link ClientHttpRequest#getAttributes() attributes}.
*
* @author Steve Riesenberg
* @since 6.4
*/
public class RequestAttributePrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver {
private static final String PRINCIPAL_ATTR_NAME = RequestAttributePrincipalResolver.class.getName()
.concat(".principal");
@Override
public Authentication resolve(HttpRequest request) {
return (Authentication) request.getAttributes().get(PRINCIPAL_ATTR_NAME);
}
/**
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
* {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}.
* @param principal the {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> principal(Authentication principal) {
Assert.notNull(principal, "principal cannot be null");
return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal);
}
/**
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
* {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}.
* @param principalName the {@code principalName} to be used to look up the
* {@link OAuth2AuthorizedClient}
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> principal(String principalName) {
Assert.hasText(principalName, "principalName cannot be empty");
Authentication principal = createAuthentication(principalName);
return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal);
}
private static Authentication createAuthentication(String principalName) {
return new AbstractAuthenticationToken(Collections.emptySet()) {
@Override
public Object getPrincipal() {
return principalName;
}
@Override
public Object getCredentials() {
return null;
}
};
}
}

View File

@ -0,0 +1,57 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.client;
import org.springframework.http.HttpRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
/**
* A strategy for resolving a {@link Authentication principal} from an intercepted request
* using the {@link SecurityContextHolder}.
*
* @author Steve Riesenberg
* @since 6.4
*/
public class SecurityContextHolderPrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver {
private final SecurityContextHolderStrategy securityContextHolderStrategy;
/**
* Constructs a {@code SecurityContextHolderPrincipalResolver}.
*/
public SecurityContextHolderPrincipalResolver() {
this(SecurityContextHolder.getContextHolderStrategy());
}
/**
* Constructs a {@code SecurityContextHolderPrincipalResolver} using the provided
* parameters.
* @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to
* use for resolving the {@link Authentication principal}
*/
public SecurityContextHolderPrincipalResolver(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}
@Override
public Authentication resolve(HttpRequest request) {
return this.securityContextHolderStrategy.getContext().getAuthentication();
}
}

View File

@ -43,7 +43,6 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
@ -110,15 +109,15 @@ public class OAuth2ClientHttpRequestInterceptorTests {
@Mock @Mock
private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientRepository authorizedClientRepository;
@Mock
private SecurityContextHolderStrategy securityContextHolderStrategy;
@Mock @Mock
private OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientService authorizedClientService;
@Mock @Mock
private OAuth2ClientHttpRequestInterceptor.ClientRegistrationIdResolver clientRegistrationIdResolver; private OAuth2ClientHttpRequestInterceptor.ClientRegistrationIdResolver clientRegistrationIdResolver;
@Mock
private OAuth2ClientHttpRequestInterceptor.PrincipalResolver principalResolver;
@Captor @Captor
private ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor; private ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor;
@ -167,13 +166,6 @@ public class OAuth2ClientHttpRequestInterceptorTests {
.withMessage("authorizedClientManager cannot be null"); .withMessage("authorizedClientManager cannot be null");
} }
@Test
public void constructorWhenClientRegistrationIdResolverIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager, null))
.withMessage("clientRegistrationIdResolver cannot be null");
}
@Test @Test
public void setAuthorizationFailureHandlerWhenNullThenThrowsIllegalArgumentException() { public void setAuthorizationFailureHandlerWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
@ -198,10 +190,16 @@ public class OAuth2ClientHttpRequestInterceptorTests {
} }
@Test @Test
public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() { public void setClientRegistrationIdResolverWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.requestInterceptor.setSecurityContextHolderStrategy(null)) .isThrownBy(() -> this.requestInterceptor.setClientRegistrationIdResolver(null))
.withMessage("securityContextHolderStrategy cannot be null"); .withMessage("clientRegistrationIdResolver cannot be null");
}
@Test
public void setPrincipalResolverWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.requestInterceptor.setPrincipalResolver(null))
.withMessage("principalResolver cannot be null");
} }
@Test @Test
@ -605,8 +603,7 @@ public class OAuth2ClientHttpRequestInterceptorTests {
@Test @Test
public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() { public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() {
this.requestInterceptor = new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager, this.requestInterceptor.setClientRegistrationIdResolver(this.clientRegistrationIdResolver);
this.clientRegistrationIdResolver);
this.requestInterceptor.setAuthorizationFailureHandler(this.authorizationFailureHandler); this.requestInterceptor.setAuthorizationFailureHandler(this.authorizationFailureHandler);
given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))) given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
.willReturn(this.authorizedClient); .willReturn(this.authorizedClient);
@ -625,7 +622,7 @@ public class OAuth2ClientHttpRequestInterceptorTests {
this.server.verify(); this.server.verify();
verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture()); verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture());
verify(this.clientRegistrationIdResolver).resolve(any(HttpRequest.class)); verify(this.clientRegistrationIdResolver).resolve(any(HttpRequest.class));
verifyNoMoreInteractions(this.clientRegistrationIdResolver, this.authorizedClientManager); verifyNoMoreInteractions(this.authorizedClientManager, this.clientRegistrationIdResolver);
verifyNoInteractions(this.authorizationFailureHandler); verifyNoInteractions(this.authorizationFailureHandler);
OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue(); OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue();
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistrationId); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistrationId);
@ -633,8 +630,8 @@ public class OAuth2ClientHttpRequestInterceptorTests {
} }
@Test @Test
public void interceptWhenCustomSecurityContextHolderStrategySetThenUsed() { public void interceptWhenCustomPrincipalResolverSetThenUsed() {
this.requestInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); this.requestInterceptor.setPrincipalResolver(this.principalResolver);
given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))) given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
.willReturn(this.authorizedClient); .willReturn(this.authorizedClient);
@ -642,14 +639,12 @@ public class OAuth2ClientHttpRequestInterceptorTests {
this.server.expect(requestTo(REQUEST_URI)) this.server.expect(requestTo(REQUEST_URI))
.andExpect(hasAuthorizationHeader(this.authorizedClient.getAccessToken())) .andExpect(hasAuthorizationHeader(this.authorizedClient.getAccessToken()))
.andRespond(withApplicationJson()); .andRespond(withApplicationJson());
SecurityContext securityContext = new SecurityContextImpl(); given(this.principalResolver.resolve(any(HttpRequest.class))).willReturn(this.principal);
securityContext.setAuthentication(this.principal);
given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext);
performRequest(withClientRegistrationId()); performRequest(withClientRegistrationId());
this.server.verify(); this.server.verify();
verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture()); verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture());
verify(this.securityContextHolderStrategy).getContext(); verify(this.principalResolver).resolve(any(HttpRequest.class));
verifyNoMoreInteractions(this.authorizedClientManager, this.securityContextHolderStrategy); verifyNoMoreInteractions(this.authorizedClientManager, this.principalResolver);
OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue(); OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue();
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId());
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);