From eb6ed283e0fe4dc8bacfdb53ff4844013c35ff5a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 16 Jun 2021 09:48:52 -0400 Subject: [PATCH] Jwt client authentication converter detects new key Closes gh-9814 --- ...ientAuthenticationParametersConverter.java | 33 +++++++++-- ...uthenticationParametersConverterTests.java | 56 +++++++++++++++++++ 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index b7c80f17b5..68bba82d5c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter jwkResolver; - private final Map jwsEncoders = new ConcurrentHashMap<>(); + private final Map jwsEncoders = new ConcurrentHashMap<>(); /** * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the @@ -140,12 +140,16 @@ public final class NimbusJwtClientAuthenticationParametersConverter { + JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(), + (clientRegistrationId, currentJwsEncoderHolder) -> { + if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) { + return currentJwsEncoderHolder; + } JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); - return new NimbusJwsEncoder(jwkSource); + return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk); }); + NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); MultiValueMap parameters = new LinkedMultiValueMap<>(); @@ -180,4 +184,25 @@ public final class NimbusJwtClientAuthenticationParametersConverter parameters = this.converter.convert(clientCredentialsGrantRequest); + + String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws); + + RSAKey rsaJwk2 = generateRsaJwk(); + given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2); + + parameters = this.converter.convert(clientCredentialsGrantRequest); + + encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws); + } + + private static RSAKey generateRsaJwk() { + KeyPair keyPair; + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + keyPair = keyPairGenerator.generateKeyPair(); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic(); + RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate(); + // @formatter:off + return new RSAKey.Builder(publicKey) + .privateKey(privateKey) + .keyID(UUID.randomUUID().toString()) + .build(); + // @formatter:on + } + }