Jwt client authentication converter detects new key
Closes gh-9814
This commit is contained in:
parent
700bda68b7
commit
6fbd038111
|
@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
|
||||||
|
|
||||||
private final Function<ClientRegistration, JWK> jwkResolver;
|
private final Function<ClientRegistration, JWK> jwkResolver;
|
||||||
|
|
||||||
private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>();
|
private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
|
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
|
||||||
|
@ -140,12 +140,16 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
|
||||||
JoseHeader joseHeader = headersBuilder.build();
|
JoseHeader joseHeader = headersBuilder.build();
|
||||||
JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
|
JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
|
||||||
|
|
||||||
NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
|
JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
|
||||||
(clientRegistrationId) -> {
|
(clientRegistrationId, currentJwsEncoderHolder) -> {
|
||||||
|
if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
|
||||||
|
return currentJwsEncoderHolder;
|
||||||
|
}
|
||||||
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
|
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
|
||||||
return new NimbusJwsEncoder(jwkSource);
|
return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder();
|
||||||
Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
|
Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
|
||||||
|
|
||||||
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
|
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
|
||||||
|
@ -180,4 +184,25 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
|
||||||
return jwsAlgorithm;
|
return jwsAlgorithm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final class JwsEncoderHolder {
|
||||||
|
|
||||||
|
private final NimbusJwsEncoder jwsEncoder;
|
||||||
|
|
||||||
|
private final JWK jwk;
|
||||||
|
|
||||||
|
private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) {
|
||||||
|
this.jwsEncoder = jwsEncoder;
|
||||||
|
this.jwk = jwk;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NimbusJwsEncoder getJwsEncoder() {
|
||||||
|
return this.jwsEncoder;
|
||||||
|
}
|
||||||
|
|
||||||
|
private JWK getJwk() {
|
||||||
|
return this.jwk;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,12 @@
|
||||||
|
|
||||||
package org.springframework.security.oauth2.client.endpoint;
|
package org.springframework.security.oauth2.client.endpoint;
|
||||||
|
|
||||||
|
import java.security.KeyPair;
|
||||||
|
import java.security.KeyPairGenerator;
|
||||||
|
import java.security.interfaces.RSAPrivateKey;
|
||||||
|
import java.security.interfaces.RSAPublicKey;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.UUID;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
import com.nimbusds.jose.jwk.JWK;
|
import com.nimbusds.jose.jwk.JWK;
|
||||||
|
@ -42,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.BDDMockito.given;
|
import static org.mockito.BDDMockito.given;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.verifyNoInteractions;
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
|
@ -172,4 +178,54 @@ public class NimbusJwtClientAuthenticationParametersConverterTests {
|
||||||
assertThat(jws.getExpiresAt()).isNotNull();
|
assertThat(jws.getExpiresAt()).isNotNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gh-9814
|
||||||
|
@Test
|
||||||
|
public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
|
||||||
|
// @formatter:off
|
||||||
|
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
|
||||||
|
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
|
||||||
|
.build();
|
||||||
|
// @formatter:on
|
||||||
|
|
||||||
|
RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
|
||||||
|
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);
|
||||||
|
|
||||||
|
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
|
||||||
|
clientRegistration);
|
||||||
|
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
|
||||||
|
|
||||||
|
String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
|
||||||
|
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
|
||||||
|
jwtDecoder.decode(encodedJws);
|
||||||
|
|
||||||
|
RSAKey rsaJwk2 = generateRsaJwk();
|
||||||
|
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);
|
||||||
|
|
||||||
|
parameters = this.converter.convert(clientCredentialsGrantRequest);
|
||||||
|
|
||||||
|
encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
|
||||||
|
jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
|
||||||
|
jwtDecoder.decode(encodedJws);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static RSAKey generateRsaJwk() {
|
||||||
|
KeyPair keyPair;
|
||||||
|
try {
|
||||||
|
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
|
||||||
|
keyPairGenerator.initialize(2048);
|
||||||
|
keyPair = keyPairGenerator.generateKeyPair();
|
||||||
|
}
|
||||||
|
catch (Exception ex) {
|
||||||
|
throw new IllegalStateException(ex);
|
||||||
|
}
|
||||||
|
RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
|
||||||
|
RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
|
||||||
|
// @formatter:off
|
||||||
|
return new RSAKey.Builder(publicKey)
|
||||||
|
.privateKey(privateKey)
|
||||||
|
.keyID(UUID.randomUUID().toString())
|
||||||
|
.build();
|
||||||
|
// @formatter:on
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue