diff --git a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java index f36e48a2cb..c522910a5c 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java +++ b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ package org.springframework.security.config.web.server; +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier; +import com.nimbusds.jose.proc.JOSEObjectTypeVerifier; +import com.nimbusds.jose.proc.JWKSecurityContext; import reactor.core.publisher.Mono; import org.springframework.security.authentication.AuthenticationProvider; @@ -23,19 +27,22 @@ import org.springframework.security.authentication.AuthenticationServiceExceptio import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory; +import org.springframework.security.oauth2.client.oidc.authentication.OidcIdTokenDecoderFactory; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * An {@link AuthenticationProvider} that authenticates an OIDC Logout Token; namely @@ -61,9 +68,27 @@ final class OidcBackChannelLogoutReactiveAuthenticationManager implements Reacti * Construct an {@link OidcBackChannelLogoutReactiveAuthenticationManager} */ OidcBackChannelLogoutReactiveAuthenticationManager() { - ReactiveOidcIdTokenDecoderFactory logoutTokenDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(); - logoutTokenDecoderFactory.setJwtValidatorFactory(new DefaultOidcLogoutTokenValidatorFactory()); - this.logoutTokenDecoderFactory = logoutTokenDecoderFactory; + DefaultOidcLogoutTokenValidatorFactory jwtValidator = new DefaultOidcLogoutTokenValidatorFactory(); + this.logoutTokenDecoderFactory = (clientRegistration) -> { + String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); + if (!StringUtils.hasText(jwkSetUri)) { + OAuth2Error oauth2Error = new OAuth2Error("missing_signature_verifier", + "Failed to find a Signature Verifier for Client Registration: '" + + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured the JwkSet URI.", + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + JOSEObjectTypeVerifier typeVerifier = new DefaultJOSEObjectTypeVerifier<>(null, + JOSEObjectType.JWT, new JOSEObjectType("logout+jwt")); + NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri) + .jwtProcessorCustomizer((processor) -> processor.setJWSTypeVerifier(typeVerifier)) + .build(); + decoder.setJwtValidator(jwtValidator.apply(clientRegistration)); + decoder.setClaimSetConverter( + new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters())); + return decoder; + }; } /** diff --git a/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java index 288c4f3a40..8514528aa7 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java @@ -75,6 +75,8 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JwsHeader; import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoderParameters; @@ -819,8 +821,9 @@ public class OidcLogoutSpecTests { String logoutToken(@AuthenticationPrincipal OidcUser user) { OidcLogoutToken token = TestOidcLogoutTokens.withUser(user) .audience(List.of(this.registration.getClientId())).build(); - JwtEncoderParameters parameters = JwtEncoderParameters - .from(JwtClaimsSet.builder().claims((claims) -> claims.putAll(token.getClaims())).build()); + JwsHeader header = JwsHeader.with(SignatureAlgorithm.RS256).type("logout+jwt").build(); + JwtClaimsSet claims = JwtClaimsSet.builder().claims((c) -> c.putAll(token.getClaims())).build(); + JwtEncoderParameters parameters = JwtEncoderParameters.from(header, claims); return this.encoder.encode(parameters).getTokenValue(); } @@ -829,8 +832,9 @@ public class OidcLogoutSpecTests { OidcLogoutToken token = TestOidcLogoutTokens.withUser(user) .audience(List.of(this.registration.getClientId())) .claims((claims) -> claims.remove(LogoutTokenClaimNames.SID)).build(); - JwtEncoderParameters parameters = JwtEncoderParameters - .from(JwtClaimsSet.builder().claims((claims) -> claims.putAll(token.getClaims())).build()); + JwsHeader header = JwsHeader.with(SignatureAlgorithm.RS256).type("JWT").build(); + JwtClaimsSet claims = JwtClaimsSet.builder().claims((c) -> c.putAll(token.getClaims())).build(); + JwtEncoderParameters parameters = JwtEncoderParameters.from(header, claims); return this.encoder.encode(parameters).getTokenValue(); } }