diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index b7c80f17b5..68bba82d5c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter jwkResolver; - private final Map jwsEncoders = new ConcurrentHashMap<>(); + private final Map jwsEncoders = new ConcurrentHashMap<>(); /** * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the @@ -140,12 +140,16 @@ public final class NimbusJwtClientAuthenticationParametersConverter { + JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(), + (clientRegistrationId, currentJwsEncoderHolder) -> { + if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) { + return currentJwsEncoderHolder; + } JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); - return new NimbusJwsEncoder(jwkSource); + return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk); }); + NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); MultiValueMap parameters = new LinkedMultiValueMap<>(); @@ -180,4 +184,25 @@ public final class NimbusJwtClientAuthenticationParametersConverter parameters = this.converter.convert(clientCredentialsGrantRequest); + + String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws); + + RSAKey rsaJwk2 = generateRsaJwk(); + given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2); + + parameters = this.converter.convert(clientCredentialsGrantRequest); + + encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws); + } + + private static RSAKey generateRsaJwk() { + KeyPair keyPair; + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + keyPair = keyPairGenerator.generateKeyPair(); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic(); + RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate(); + // @formatter:off + return new RSAKey.Builder(publicKey) + .privateKey(privateKey) + .keyID(UUID.randomUUID().toString()) + .build(); + // @formatter:on + } + }