Jwt client authentication converter detects new key

Closes gh-9814
This commit is contained in:
Joe Grandja 2021-06-16 09:48:52 -04:00
parent 700bda68b7
commit 6fbd038111
2 changed files with 85 additions and 4 deletions

View File

@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
private final Function<ClientRegistration, JWK> jwkResolver; private final Function<ClientRegistration, JWK> jwkResolver;
private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>(); private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();
/** /**
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
@ -140,12 +140,16 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
JoseHeader joseHeader = headersBuilder.build(); JoseHeader joseHeader = headersBuilder.build();
JwtClaimsSet jwtClaimsSet = claimsBuilder.build(); JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
(clientRegistrationId) -> { (clientRegistrationId, currentJwsEncoderHolder) -> {
if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
return currentJwsEncoderHolder;
}
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
return new NimbusJwsEncoder(jwkSource); return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk);
}); });
NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder();
Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
@ -180,4 +184,25 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
return jwsAlgorithm; return jwsAlgorithm;
} }
private static final class JwsEncoderHolder {
private final NimbusJwsEncoder jwsEncoder;
private final JWK jwk;
private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) {
this.jwsEncoder = jwsEncoder;
this.jwk = jwk;
}
private NimbusJwsEncoder getJwsEncoder() {
return this.jwsEncoder;
}
private JWK getJwk() {
return this.jwk;
}
}
} }

View File

@ -16,7 +16,12 @@
package org.springframework.security.oauth2.client.endpoint; package org.springframework.security.oauth2.client.endpoint;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Collections; import java.util.Collections;
import java.util.UUID;
import java.util.function.Function; import java.util.function.Function;
import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWK;
@ -42,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
@ -172,4 +178,54 @@ public class NimbusJwtClientAuthenticationParametersConverterTests {
assertThat(jws.getExpiresAt()).isNotNull(); assertThat(jws.getExpiresAt()).isNotNull();
} }
// gh-9814
@Test
public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws);
RSAKey rsaJwk2 = generateRsaJwk();
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);
parameters = this.converter.convert(clientCredentialsGrantRequest);
encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws);
}
private static RSAKey generateRsaJwk() {
KeyPair keyPair;
try {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
keyPair = keyPairGenerator.generateKeyPair();
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
// @formatter:off
return new RSAKey.Builder(publicKey)
.privateKey(privateKey)
.keyID(UUID.randomUUID().toString())
.build();
// @formatter:on
}
} }