Add PrincipalResolver to ExchangeFilterFunctions

Closes gh-16284

Signed-off-by: Evgeniy Cheban <mister.cheban@gmail.com>
This commit is contained in:
Evgeniy Cheban 2026-03-13 05:32:36 +02:00 committed by Joe Grandja
parent aa35db5aad
commit 8f2a5a7b6e
4 changed files with 169 additions and 5 deletions

View File

@ -119,8 +119,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER"));
private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext()
.flatMap((ctx) -> Mono.justOrEmpty(ctx.getAuthentication()))
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
.mapNotNull(SecurityContext::getAuthentication);
// @formatter:off
private final Mono<String> clientRegistrationIdMono = this.currentAuthenticationMono
@ -145,6 +144,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
private PrincipalResolver principalResolver = (request) -> this.currentAuthenticationMono;
/**
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
* provided parameters.
@ -332,6 +333,15 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return this.principalResolver.resolve(request)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN)
.flatMap((authentication) -> doFilter(request, next)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
// @formatter:on
}
private Mono<ClientResponse> doFilter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return authorizedClient(request)
.map((authorizedClient) -> bearer(request, authorizedClient))
@ -483,6 +493,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
this.serverSecurityContextRepository = serverSecurityContextRepository;
}
/**
* Sets the strategy for resolving a {@link Mono} of the {@link Authentication
* principal} from an intercepted request.
* @param principalResolver the strategy for resolving a {@link Mono} of the
* {@link Authentication principal}
* @since 7.1
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}
@FunctionalInterface
private interface ClientResponseHandler {
@ -490,6 +512,27 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
}
/**
* A strategy for resolving a {@link Mono} of the {@link Authentication principal}
* from an intercepted request.
*
* @since 7.1
*/
@FunctionalInterface
public interface PrincipalResolver {
/**
* Resolve a {@link Mono} of the {@link Authentication principal} from the current
* request, which is used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Mono} of the {@link Authentication principal} to be used for
* resolving an {@link OAuth2AuthorizedClient}
*/
Mono<Authentication> resolve(ClientRequest request);
}
/**
* Forwards authentication and authorization failures to a
* {@link ReactiveOAuth2AuthorizationFailureHandler}.

View File

@ -123,6 +123,7 @@ import org.springframework.web.reactive.function.client.WebClientResponseExcepti
* @author Rob Winch
* @author Joe Grandja
* @author Roman Matiushchenko
* @author Evgeniy Cheban
* @since 5.1
* @see OAuth2AuthorizedClientManager
* @see DefaultOAuth2AuthorizedClientManager
@ -154,6 +155,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private @Nullable OAuth2AuthorizedClientManager authorizedClientManager;
/*
* For consistency, the default implementation resolves a principal from request
* attributes. Request attributes are populated from Reactor context which is enriched
* in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber
*/
private PrincipalResolver principalResolver = (request) -> getAuthentication(request.attributes());
private boolean defaultOAuth2AuthorizedClient;
private @Nullable String defaultClientRegistrationId;
@ -375,6 +383,18 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}
/**
* Sets the strategy for resolving a {@link Authentication principal} from an
* intercepted request.
* @param principalResolver the strategy for resolving a {@link Authentication
* principal}
* @since 7.1
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
@ -471,7 +491,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
if (clientRegistrationId == null) {
clientRegistrationId = this.defaultClientRegistrationId;
}
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient
&& authentication instanceof OAuth2AuthenticationToken) {
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
@ -485,7 +505,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return Mono.empty();
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (authentication == null) {
authentication = ANONYMOUS_AUTHENTICATION;
}
@ -512,7 +532,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return Mono.empty();
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (authentication == null) {
authentication = createAuthentication(authorizedClient.getPrincipalName());
}
@ -587,6 +607,27 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
};
}
/**
* A strategy for resolving a {@link Authentication principal} from an intercepted
* request.
*
* @since 7.1
*/
@FunctionalInterface
public interface PrincipalResolver {
/**
* Resolve a {@link Authentication principal} from the current request, which is
* used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Mono} of the {@link Authentication principal} to be used for
* resolving an {@link OAuth2AuthorizedClient}
*/
@Nullable Authentication resolve(ClientRequest request);
}
@FunctionalInterface
private interface ClientResponseHandler {

View File

@ -218,6 +218,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.setServerSecurityContextRepository(null));
}
@Test
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setPrincipalResolver(null));
}
@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
@ -791,6 +798,38 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request0)).isEmpty();
}
@Test
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"),
Collections.singletonMap("user", "rob"), "user");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, this.serverWebExchange))
.willReturn(Mono.just(authorizedClient));
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> Mono.just(authentication));
this.function.filter(clientRequest, this.exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(initialAuthentication))
.contextWrite(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0);
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request0)).isEmpty();
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, this.serverWebExchange);
}
@Test
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();

View File

@ -125,6 +125,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
@ExtendWith(MockitoExtension.class)
@ -217,6 +218,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null));
}
@Test
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setPrincipalResolver(null));
}
@Test
public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() {
Map<String, Object> attrs = getDefaultRequestAttributes();
@ -620,6 +628,39 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, authorities,
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities,
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
initialAuthentication, servletRequest))
.willReturn(authorizedClient);
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> authentication);
this.function.filter(clientRequest, this.exchange)
.contextWrite(context(servletRequest, servletResponse, initialAuthentication))
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, servletRequest);
}
@Test
public void filterWhenUnauthorizedThenInvokeFailureHandler() {
assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);