Fix DPoP jkt claim to be JWK SHA-256 thumbprint

Just used the nimbus JOSE library to do it, because it already has a
compliant implementation.

Closes gh-17080

Signed-off-by: David Kowis <david@kow.is>
This commit is contained in:
David Kowis 2025-05-08 12:35:59 -05:00 committed by Joe Grandja
parent 8b925dc4fc
commit 462e38c0e3
2 changed files with 11 additions and 20 deletions

View File

@ -210,25 +210,22 @@ public final class DPoPAuthenticationProvider implements AuthenticationProvider
return OAuth2TokenValidatorResult.failure(error); return OAuth2TokenValidatorResult.failure(error);
} }
PublicKey publicKey = null; JWK jwk = null;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, Object> jwkJson = (Map<String, Object>) jwt.getHeaders().get("jwk"); Map<String, Object> jwkJson = (Map<String, Object>) jwt.getHeaders().get("jwk");
try { try {
JWK jwk = JWK.parse(jwkJson); jwk = JWK.parse(jwkJson);
if (jwk instanceof AsymmetricJWK) {
publicKey = ((AsymmetricJWK) jwk).toPublicKey();
}
} }
catch (Exception ignored) { catch (Exception ignored) {
} }
if (publicKey == null) { if (jwk == null) {
OAuth2Error error = createOAuth2Error("jwk header is missing or invalid."); OAuth2Error error = createOAuth2Error("jwk header is missing or invalid.");
return OAuth2TokenValidatorResult.failure(error); return OAuth2TokenValidatorResult.failure(error);
} }
String jwkThumbprint; String jwkThumbprint;
try { try {
jwkThumbprint = computeSHA256(publicKey); jwkThumbprint = jwk.computeThumbprint().toString();
} }
catch (Exception ex) { catch (Exception ex) {
OAuth2Error error = createOAuth2Error("Failed to compute SHA-256 Thumbprint for jwk."); OAuth2Error error = createOAuth2Error("Failed to compute SHA-256 Thumbprint for jwk.");

View File

@ -26,6 +26,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SecurityContext;
@ -218,8 +219,8 @@ public class DPoPAuthenticationProviderTests {
@Test @Test
public void authenticateWhenJktDoesNotMatchThenThrowOAuth2AuthenticationException() throws Exception { public void authenticateWhenJktDoesNotMatchThenThrowOAuth2AuthenticationException() throws Exception {
// Use different client public key // Use different jwk to make it not match
Jwt accessToken = generateAccessToken(TestKeys.DEFAULT_EC_KEY_PAIR.getPublic()); Jwt accessToken = generateAccessToken(TestJwks.DEFAULT_EC_JWK);
JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken); JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(accessToken);
given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken); given(this.tokenAuthenticationManager.authenticate(any())).willReturn(jwtAuthenticationToken);
@ -285,14 +286,14 @@ public class DPoPAuthenticationProviderTests {
} }
private Jwt generateAccessToken() { private Jwt generateAccessToken() {
return generateAccessToken(TestKeys.DEFAULT_PUBLIC_KEY); return generateAccessToken(TestJwks.DEFAULT_RSA_JWK);
} }
private Jwt generateAccessToken(PublicKey clientPublicKey) { private Jwt generateAccessToken(JWK clientJwk) {
Map<String, Object> jktClaim = null; Map<String, Object> jktClaim = null;
if (clientPublicKey != null) { if (clientJwk != null) {
try { try {
String sha256Thumbprint = computeSHA256(clientPublicKey); String sha256Thumbprint = clientJwk.computeThumbprint().toString();
jktClaim = new HashMap<>(); jktClaim = new HashMap<>();
jktClaim.put("jkt", sha256Thumbprint); jktClaim.put("jkt", sha256Thumbprint);
} }
@ -321,11 +322,4 @@ public class DPoPAuthenticationProviderTests {
byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8)); byte[] digest = md.digest(value.getBytes(StandardCharsets.UTF_8));
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
} }
private static String computeSHA256(PublicKey publicKey) throws Exception {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] digest = md.digest(publicKey.getEncoded());
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
}
} }