From c208410a9191ce4060259299946b0ccb25a8dcf9 Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Thu, 26 Feb 2026 06:18:55 -0700 Subject: [PATCH] Polish Jwt Authentication Converter - Replace conditional logic with adapter class - Added tests Issue gh-6237 Signed-off-by: Josh Cummings <3627351+jzheaux@users.noreply.github.com> --- .../JwtAuthenticationConverter.java | 60 +++++++++++++------ .../JwtAuthenticationToken.java | 6 +- ...JwtBearerTokenAuthenticationConverter.java | 27 +++++---- .../JwtAuthenticationConverterTests.java | 21 +++++++ .../JwtAuthenticationTokenTests.java | 2 +- ...arerTokenAuthenticationConverterTests.java | 21 +++++++ 6 files changed, 104 insertions(+), 33 deletions(-) diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java index 29a00ce1ed..2b778dec13 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverter.java @@ -18,6 +18,8 @@ package org.springframework.security.oauth2.server.resource.authentication; import java.util.Collection; import java.util.HashSet; +import java.util.List; +import java.util.Map; import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AbstractAuthenticationToken; @@ -40,34 +42,26 @@ public class JwtAuthenticationConverter implements Converter jwtPrincipalConverter; - private Converter> jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); + private Converter jwtPrincipalConverter = JwtAuthenticatedPrincipal::new; - private String principalClaimName = JwtClaimNames.SUB; + private Converter> jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); @Override public final AbstractAuthenticationToken convert(Jwt jwt) { Collection authorities = new HashSet<>(this.jwtGrantedAuthoritiesConverter.convert(jwt)); authorities.add(FactorGrantedAuthority.fromAuthority(AUTHORITY)); - - if (this.jwtPrincipalConverter == null) { - String principalClaimValue = jwt.getClaimAsString(this.principalClaimName); - return new JwtAuthenticationToken(jwt, authorities, principalClaimValue); - } else { - OAuth2AuthenticatedPrincipal principal = this.jwtPrincipalConverter.convert(jwt); - authorities.addAll(principal.getAuthorities()); - return new JwtAuthenticationToken(jwt, principal, authorities); - } + OAuth2AuthenticatedPrincipal principal = this.jwtPrincipalConverter.convert(jwt); + authorities.addAll(principal.getAuthorities()); + return new JwtAuthenticationToken(jwt, principal, authorities); } /** - * Sets the {@link Converter Converter<Jwt, Collection<OAuth2AuthenticatedPrincipal>>} - * to use. + * Sets the {@link Converter Converter<Jwt, OAuth2AuthenticatedPrincipal>} to + * use. * @param jwtPrincipalConverter The converter - * @since 6.5.0 + * @since 7.1 */ - public void setJwtPrincipalConverter( - Converter jwtPrincipalConverter) { + public void setJwtPrincipalConverter(Converter jwtPrincipalConverter) { Assert.notNull(jwtPrincipalConverter, "jwtPrincipalConverter cannot be null"); this.jwtPrincipalConverter = jwtPrincipalConverter; } @@ -92,7 +86,37 @@ public class JwtAuthenticationConverter implements Converter new JwtAuthenticatedPrincipal(jwt, principalClaimName); + } + + private static final class JwtAuthenticatedPrincipal extends Jwt implements OAuth2AuthenticatedPrincipal { + + private final String principalClaimName; + + JwtAuthenticatedPrincipal(Jwt jwt) { + this(jwt, JwtClaimNames.SUB); + } + + JwtAuthenticatedPrincipal(Jwt jwt, String principalClaimName) { + super(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getHeaders(), jwt.getClaims()); + this.principalClaimName = principalClaimName; + } + + @Override + public Map getAttributes() { + return getClaims(); + } + + @Override + public Collection getAuthorities() { + return List.of(); + } + + @Override + public String getName() { + return getClaimAsString(this.principalClaimName); + } + } } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java index 65204c4065..5801c043b4 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationToken.java @@ -21,8 +21,8 @@ import java.util.Map; import org.jspecify.annotations.Nullable; -import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticatedPrincipal; +import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.Transient; import org.springframework.security.oauth2.jwt.Jwt; @@ -87,13 +87,15 @@ public class JwtAuthenticationToken extends AbstractOAuth2TokenAuthenticationTok * @param jwt the JWT * @param principal the principal * @param authorities the authorities assigned to the JWT + * @since 7.1 */ public JwtAuthenticationToken(Jwt jwt, Object principal, Collection authorities) { super(jwt, principal, jwt, authorities); this.setAuthenticated(true); if (principal instanceof AuthenticatedPrincipal) { this.name = ((AuthenticatedPrincipal) principal).getName(); - } else { + } + else { this.name = jwt.getSubject(); } } diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java index e657f532ed..64cb8b905c 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverter.java @@ -17,7 +17,6 @@ package org.springframework.security.oauth2.server.resource.authentication; import java.util.Collection; -import java.util.Map; import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AbstractAuthenticationToken; @@ -47,29 +46,33 @@ import org.springframework.util.Assert; */ public final class JwtBearerTokenAuthenticationConverter implements Converter { - private final JwtAuthenticationConverter jwtAuthenticationConverter = new JwtAuthenticationConverter(); + private Converter> jwtGrantedAuthoritiesConverter = new JwtGrantedAuthoritiesConverter(); + + private Converter jwtPrincipalConverter = ( + jwt) -> new DefaultOAuth2AuthenticatedPrincipal(jwt.getClaims(), + this.jwtGrantedAuthoritiesConverter.convert(jwt)); @Override public AbstractAuthenticationToken convert(Jwt jwt) { OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt()); - Map attributes = jwt.getClaims(); - AbstractAuthenticationToken token = this.jwtAuthenticationConverter.convert(jwt); - Collection authorities = token.getAuthorities(); - OAuth2AuthenticatedPrincipal principal = new DefaultOAuth2AuthenticatedPrincipal(attributes, authorities); + Collection authorities = this.jwtGrantedAuthoritiesConverter.convert(jwt); + OAuth2AuthenticatedPrincipal principal = this.jwtPrincipalConverter.convert(jwt); return new BearerTokenAuthentication(principal, accessToken, authorities); } /** - * Sets the {@link Converter Converter<Jwt, Collection<OAuth2AuthenticatedPrincipal>>} - * to use. + * Sets the {@link Converter Converter<Jwt, OAuth2AuthenticatedPrincipal>} to + * use. + *

+ * By default, constructs a {@link DefaultOAuth2AuthenticatedPrincipal} based on the + * claims and authorities derived from the {@link Jwt}. * @param jwtPrincipalConverter The converter - * @since 6.5.0 + * @since 7.1 */ - public void setJwtPrincipalConverter( - Converter jwtPrincipalConverter) { + public void setJwtPrincipalConverter(Converter jwtPrincipalConverter) { Assert.notNull(jwtPrincipalConverter, "jwtPrincipalConverter cannot be null"); - this.jwtAuthenticationConverter.setJwtPrincipalConverter(jwtPrincipalConverter); + this.jwtPrincipalConverter = jwtPrincipalConverter; } } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java index 9e085f15e1..cc08a53afd 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationConverterTests.java @@ -18,6 +18,8 @@ package org.springframework.security.oauth2.server.resource.authentication; import java.util.Arrays; import java.util.Collection; +import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; @@ -28,6 +30,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.FactorGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.TestJwts; @@ -119,4 +123,21 @@ public class JwtAuthenticationConverterTests { SecurityAssertions.assertThat(result).hasAuthority(FactorGrantedAuthority.BEARER_AUTHORITY); } + @Test + public void whenSettingNullJwtPrincipalConverter() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.jwtAuthenticationConverter.setJwtPrincipalConverter(null)) + .withMessage("jwtPrincipalConverter cannot be null"); + } + + @Test + public void convertWhenJwtPrincipalConverterSetThenCustomPrincipalUsed() { + OAuth2AuthenticatedPrincipal customPrincipal = new DefaultOAuth2AuthenticatedPrincipal("custom-name", + Map.of("sub", "custom-name"), List.of()); + this.jwtAuthenticationConverter.setJwtPrincipalConverter((jwt) -> customPrincipal); + Jwt jwt = TestJwts.jwt().build(); + AbstractAuthenticationToken authentication = this.jwtAuthenticationConverter.convert(jwt); + assertThat(authentication.getName()).isEqualTo("custom-name"); + } + } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java index 2114c6f87b..4105dbf69a 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationTokenTests.java @@ -111,7 +111,7 @@ public class JwtAuthenticationTokenTests { public void getNameWhenConstructedWithNoSubjectThenReturnsNull() { Collection authorities = AuthorityUtils.createAuthorityList("test"); Jwt jwt = builder().claim("claim", "value").build(); - assertThat(new JwtAuthenticationToken(jwt, authorities, null).getName()).isNull(); + assertThat(new JwtAuthenticationToken(jwt, authorities, (String) null).getName()).isNull(); assertThat(new JwtAuthenticationToken(jwt, authorities).getName()).isNull(); assertThat(new JwtAuthenticationToken(jwt).getName()).isNull(); } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java index 2edbf0f754..964d45ed97 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtBearerTokenAuthenticationConverterTests.java @@ -17,6 +17,8 @@ package org.springframework.security.oauth2.server.resource.authentication; import java.util.Arrays; +import java.util.List; +import java.util.Map; import java.util.function.Predicate; import org.junit.jupiter.api.Test; @@ -24,6 +26,8 @@ import org.junit.jupiter.api.Test; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.SecurityAssertions; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal; +import org.springframework.security.oauth2.core.OAuth2AuthenticatedPrincipal; import org.springframework.security.oauth2.jwt.Jwt; import static org.assertj.core.api.Assertions.assertThat; @@ -81,6 +85,23 @@ public class JwtBearerTokenAuthenticationConverterTests { SecurityAssertions.assertThat(bearerToken).hasAuthorities("SCOPE_message:read", "SCOPE_message:write"); } + @Test + public void convertWhenJwtPrincipalConverterSetThenCustomPrincipalUsed() { + OAuth2AuthenticatedPrincipal customPrincipal = new DefaultOAuth2AuthenticatedPrincipal("custom-name", + Map.of("claim", "value"), List.of()); + this.converter.setJwtPrincipalConverter((jwt) -> customPrincipal); + // @formatter:off + Jwt jwt = Jwt.withTokenValue("token-value") + .claim("claim", "value") + .header("header", "value") + .build(); + // @formatter:on + AbstractAuthenticationToken token = this.converter.convert(jwt); + assertThat(token).isInstanceOf(BearerTokenAuthentication.class); + BearerTokenAuthentication bearerToken = (BearerTokenAuthentication) token; + assertThat(bearerToken.getName()).isEqualTo("custom-name"); + } + static Predicate isScope() { return (a) -> a.getAuthority().startsWith("SCOPE_"); }