diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 73c2b9be44..6ac6e712b1 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -230,7 +230,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { .getConfigurationForIssuerLocation(issuer, rest); JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer); return configuration.get("jwks_uri").toString(); - }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms); + }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms) + .validator(JwtValidators.createDefaultWithIssuer(issuer)); } /** @@ -289,6 +290,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { private Consumer> jwtProcessorCustomizer; + private OAuth2TokenValidator validator = JwtValidators.createDefault(); + private JwkSetUriJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = (rest) -> jwkSetUri; @@ -423,6 +426,12 @@ public final class NimbusJwtDecoder implements JwtDecoder { return this; } + JwkSetUriJwtDecoderBuilder validator(OAuth2TokenValidator validator) { + Assert.notNull(validator, "validator cannot be null"); + this.validator = validator; + return this; + } + JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(this.defaultAlgorithms.apply(jwkSource), jwkSource); @@ -461,7 +470,9 @@ public final class NimbusJwtDecoder implements JwtDecoder { * @return the configured {@link NimbusJwtDecoder} */ public NimbusJwtDecoder build() { - return new NimbusJwtDecoder(processor()); + NimbusJwtDecoder decoder = new NimbusJwtDecoder(processor()); + decoder.setJwtValidator(this.validator); + return decoder; } private static final class SpringJWKSource implements JWKSetSource { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index f3a38d812b..b31213dff9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -241,7 +241,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } return Mono.just(configuration.get("jwks_uri").toString()); }), - ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms); + ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms) + .validator(JwtValidators.createDefaultWithIssuer(issuer)); } /** @@ -332,6 +333,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private BiFunction, Mono>> jwtProcessorCustomizer; + private OAuth2TokenValidator validator = JwtValidators.createDefault(); + private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = (web) -> Mono.just(jwkSetUri); @@ -456,6 +459,11 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return this; } + JwkSetUriReactiveJwtDecoderBuilder validator(OAuth2TokenValidator validator) { + this.validator = validator; + return this; + } + JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer( BiFunction, Mono>> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); @@ -468,7 +476,9 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { - return new NimbusReactiveJwtDecoder(processor()); + NimbusReactiveJwtDecoder decoder = new NimbusReactiveJwtDecoder(processor()); + decoder.setJwtValidator(this.validator); + return decoder; } Mono> jwsKeySelector(ReactiveRemoteJWKSource source) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index ef8de01096..dc49325da2 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -328,11 +328,26 @@ public class NimbusJwtDecoderTests { .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK)); given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); - JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build(); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer) + .restOperations(restOperations) + .build(); + jwtDecoder.setJwtValidator(JwtValidators.createDefault()); Jwt jwt = jwtDecoder.decode(SIGNED_JWT); assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); } + @Test + public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() { + String issuer = "https://example.org/wrong-issuer"; + RestOperations restOperations = mock(RestOperations.class); + given(restOperations.exchange(any(RequestEntity.class), any(ParameterizedTypeReference.class))) + .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK)); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); + JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build(); + assertThatExceptionOfType(JwtValidationException.class).isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)); + } + @Test public void withJwkSetUriWhenNullOrEmptyThenThrowsException() { // @formatter:off diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 5066339c88..e775e2618c 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -617,11 +617,31 @@ public class NimbusReactiveJwtDecoderTests { given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class))) .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"))); given(spec.retrieve()).willReturn(responseSpec); + NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer) + .webClient(webClient) + .build(); + jwtDecoder.setJwtValidator(JwtValidators.createDefault()); + Jwt jwt = jwtDecoder.decode(this.messageReadToken).block(); + assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); + } + + @Test + public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() { + String issuer = "https://example.org/wrong-issuer"; + WebClient real = WebClient.builder().build(); + WebClient.RequestHeadersUriSpec spec = spy(real.get()); + WebClient webClient = spy(WebClient.class); + given(webClient.get()).willReturn(spec); + WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(this.jwkSet)); + given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class))) + .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"))); + given(spec.retrieve()).willReturn(responseSpec); ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer) .webClient(webClient) .build(); - Jwt jwt = jwtDecoder.decode(this.messageReadToken).block(); - assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> jwtDecoder.decode(this.messageReadToken).block()); } @Test