Allow Jwt assertion to be resolved

Closes gh-9812
This commit is contained in:
Joe Grandja 2022-01-07 13:23:02 -05:00
parent 6c5fd38a3f
commit 525f40490c
6 changed files with 131 additions and 15 deletions

View File

@ -1098,3 +1098,9 @@ class OAuth2ResourceServerController {
}
----
====
[NOTE]
`JwtBearerReactiveOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.
[TIP]
If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerReactiveOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Mono<Jwt>>`.

View File

@ -1352,3 +1352,9 @@ class OAuth2ResourceServerController {
}
----
====
[NOTE]
`JwtBearerOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.
[TIP]
If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Jwt>`.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient;
@ -45,6 +46,8 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient();
private Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = this::resolveJwtAssertion;
private Duration clockSkew = Duration.ofSeconds(60);
private Clock clock = Clock.systemUTC();
@ -75,10 +78,10 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
// need for re-authorization
return null;
}
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
Jwt jwt = this.jwtAssertionResolver.apply(context);
if (jwt == null) {
return null;
}
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
// As per spec, in section 4.1 Using Assertions as Authorization Grants
// https://tools.ietf.org/html/rfc7521#section-4.1
//
@ -97,6 +100,13 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
tokenResponse.getAccessToken());
}
private Jwt resolveJwtAssertion(OAuth2AuthorizationContext context) {
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
return null;
}
return (Jwt) context.getPrincipal().getPrincipal();
}
private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration,
JwtBearerGrantRequest jwtBearerGrantRequest) {
try {
@ -123,6 +133,17 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
this.accessTokenResponseClient = accessTokenResponseClient;
}
/**
* Sets the resolver used for resolving the {@link Jwt} assertion.
* @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
* assertion
* @since 5.7
*/
public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver) {
Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
this.jwtAssertionResolver = jwtAssertionResolver;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;
import reactor.core.publisher.Mono;
@ -45,6 +46,8 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient();
private Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = this::resolveJwtAssertion;
private Duration clockSkew = Duration.ofSeconds(60);
private Clock clock = Clock.systemUTC();
@ -74,10 +77,7 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
// need for re-authorization
return Mono.empty();
}
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
return Mono.empty();
}
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
// As per spec, in section 4.1 Using Assertions as Authorization Grants
// https://tools.ietf.org/html/rfc7521#section-4.1
//
@ -90,13 +90,26 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
// issued with a reasonably short lifetime. Clients can refresh an
// expired access token by requesting a new one using the same
// assertion, if it is still valid, or with a new assertion.
return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt))
// @formatter:off
return this.jwtAssertionResolver.apply(context)
.map((jwt) -> new JwtBearerGrantRequest(clientRegistration, jwt))
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.onErrorMap(OAuth2AuthorizationException.class,
(ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(),
ex))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken()));
// @formatter:on
}
private Mono<Jwt> resolveJwtAssertion(OAuth2AuthorizationContext context) {
// @formatter:off
return Mono.just(context)
.map((ctx) -> ctx.getPrincipal().getPrincipal())
.filter((principal) -> principal instanceof Jwt)
.cast(Jwt.class);
// @formatter:on
}
private boolean hasTokenExpired(OAuth2Token token) {
@ -115,6 +128,17 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
this.accessTokenResponseClient = accessTokenResponseClient;
}
/**
* Sets the resolver used for resolving the {@link Jwt} assertion.
* @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
* assertion
* @since 5.7
*/
public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver) {
Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
this.jwtAssertionResolver = jwtAssertionResolver;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -42,6 +43,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link JwtBearerOAuth2AuthorizedClientProvider}.
@ -87,6 +89,13 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
.withMessage("accessTokenResponseClient cannot be null");
}
@Test
public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
.withMessage("jwtAssertionResolver cannot be null");
}
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -198,7 +207,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
}
@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
@ -209,7 +218,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
}
@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
@ -224,4 +233,25 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
@Test
public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = mock(Function.class);
given(jwtAssertionResolver.apply(any())).willReturn(this.jwtAssertion);
this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
verify(jwtAssertionResolver).apply(any());
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -93,6 +94,13 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
.withMessage("accessTokenResponseClient cannot be null");
}
@Test
public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
.withMessage("jwtAssertionResolver cannot be null");
}
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
@ -222,7 +230,7 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
}
@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
@ -251,7 +259,7 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
}
@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
@ -266,4 +274,25 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
@Test
public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = mock(Function.class);
given(jwtAssertionResolver.apply(any())).willReturn(Mono.just(this.jwtAssertion));
this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
verify(jwtAssertionResolver).apply(any());
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
}