diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index da80e3e40d..97b48ca701 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -36,12 +36,16 @@ import java.util.concurrent.Callable; import javax.crypto.SecretKey; +import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEObjectType; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; import com.nimbusds.jose.crypto.MACSigner; import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.BadJOSEException; import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier; @@ -82,6 +86,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -660,6 +665,81 @@ public class NimbusJwtDecoderTests { verifyNoInteractions(restOperations); } + @Test + public void decodeWhenCacheAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException { + RestOperations restOperations = mock(RestOperations.class); + + Cache cache = mock(Cache.class); + given(cache.get(eq(JWK_SET_URI), any(Callable.class))).willReturn(JWK_SET); + + RSAKey rsaJWK = new RSAKeyGenerator(2048) + .keyID("new_kid") + .generate(); + String jwkSetWithNewKid = new JWKSet(rsaJWK).toPublicJWKSet().toString(); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(jwkSetWithNewKid, HttpStatus.OK)); + + // @formatter:off + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) + .cache(cache) + .restOperations(restOperations) + .build(); + // @formatter:on + + // Decode JWT with new KID + JWSSigner signer = new RSASSASigner(rsaJWK); + JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() + .expirationTime(Date.from(Instant.now().plusSeconds(60))) + .build(); + SignedJWT signedJWT = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), claimsSet); + signedJWT.sign(signer); + String token = signedJWT.serialize(); + + jwtDecoder.decode(token); + + ArgumentCaptor requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class); + verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class)); + verifyNoMoreInteractions(restOperations); + assertThat(requestEntityCaptor.getValue().getHeaders().getAccept()).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON); + } + + @Test + public void decodeWithoutCacheSpecifiedAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException { + RestOperations restOperations = mock(RestOperations.class); + + RSAKey rsaJWK = new RSAKeyGenerator(2048) + .keyID("new_kid") + .generate(); + String jwkSetWithNewKid = new JWKSet(rsaJWK).toPublicJWKSet().toString(); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK), new ResponseEntity<>(jwkSetWithNewKid, HttpStatus.OK)); + + // @formatter:off + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) + .restOperations(restOperations) + .build(); + // @formatter:on + jwtDecoder.decode(SIGNED_JWT); + + // Decode JWT with new KID + JWSSigner signer = new RSASSASigner(rsaJWK); + JWTClaimsSet claimsSet = new JWTClaimsSet.Builder() + .expirationTime(Date.from(Instant.now().plusSeconds(60))) + .build(); + SignedJWT signedJWT = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), claimsSet); + signedJWT.sign(signer); + String token = signedJWT.serialize(); + + jwtDecoder.decode(token); + + ArgumentCaptor requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class); + verify(restOperations, times(2)).exchange(requestEntityCaptor.capture(), eq(String.class)); + verifyNoMoreInteractions(restOperations); + List requestEntities = requestEntityCaptor.getAllValues(); + assertThat(requestEntities.get(0).getHeaders().getAccept()).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON); + assertThat(requestEntities.get(1).getHeaders().getAccept()).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON); + } + @Test public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtException() { Cache cache = new ConcurrentMapCache("test-jwk-set-cache");