diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 5fb07130e1..62f917e3cb 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -235,7 +235,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { Object jwksUri = configuration.get("jwks_uri"); Assert.notNull(jwksUri, "The public JWK Set URI must not be null"); return jwksUri.toString(); - }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms); + }, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms) + .validator(JwtValidators.createDefaultWithIssuer(issuer)); } /** @@ -304,6 +305,8 @@ public final class NimbusJwtDecoder implements JwtDecoder { private Consumer> jwtProcessorCustomizer; + private OAuth2TokenValidator validator = JwtValidators.createDefault(); + private JwkSetUriJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = (rest) -> jwkSetUri; @@ -444,6 +447,12 @@ public final class NimbusJwtDecoder implements JwtDecoder { return this; } + JwkSetUriJwtDecoderBuilder validator(OAuth2TokenValidator validator) { + Assert.notNull(validator, "validator cannot be null"); + this.validator = validator; + return this; + } + JWSKeySelector jwsKeySelector(JWKSource jwkSource) { if (this.signatureAlgorithms.isEmpty()) { return new JWSVerificationKeySelector<>(this.defaultAlgorithms.apply(jwkSource), jwkSource); @@ -482,7 +491,9 @@ public final class NimbusJwtDecoder implements JwtDecoder { * @return the configured {@link NimbusJwtDecoder} */ public NimbusJwtDecoder build() { - return new NimbusJwtDecoder(processor()); + NimbusJwtDecoder decoder = new NimbusJwtDecoder(processor()); + decoder.setJwtValidator(this.validator); + return decoder; } private static final class SpringJWKSource implements JWKSetSource { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 4e103f2daa..4e5c357a18 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -244,7 +244,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { Assert.notNull(jwksUri, "The public JWK Set URI must not be null"); return Mono.just(jwksUri.toString()); }), - ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms); + ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms) + .validator(JwtValidators.createDefaultWithIssuer(issuer)); } /** @@ -335,6 +336,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private BiFunction, Mono>> jwtProcessorCustomizer; + private OAuth2TokenValidator validator = JwtValidators.createDefault(); + private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) { Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty"); this.jwkSetUri = (web) -> Mono.just(jwkSetUri); @@ -459,6 +462,11 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return this; } + JwkSetUriReactiveJwtDecoderBuilder validator(OAuth2TokenValidator validator) { + this.validator = validator; + return this; + } + JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer( BiFunction, Mono>> jwtProcessorCustomizer) { Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null"); @@ -471,7 +479,9 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { * @return the configured {@link NimbusReactiveJwtDecoder} */ public NimbusReactiveJwtDecoder build() { - return new NimbusReactiveJwtDecoder(processor()); + NimbusReactiveJwtDecoder decoder = new NimbusReactiveJwtDecoder(processor()); + decoder.setJwtValidator(this.validator); + return decoder; } Mono> jwsKeySelector(ReactiveRemoteJWKSource source) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index c1866720a4..6bd1113e0b 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -332,7 +332,10 @@ public class NimbusJwtDecoderTests { .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK)); given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); - JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build(); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer) + .restOperations(restOperations) + .build(); + jwtDecoder.setJwtValidator(JwtValidators.createDefault()); Jwt jwt = jwtDecoder.decode(SIGNED_JWT); assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); } @@ -350,6 +353,18 @@ public class NimbusJwtDecoderTests { assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); } + @Test + public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() { + String issuer = "https://example.org/wrong-issuer"; + RestOperations restOperations = mock(RestOperations.class); + given(restOperations.exchange(any(RequestEntity.class), any(ParameterizedTypeReference.class))) + .willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK)); + given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) + .willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK)); + JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build(); + assertThatExceptionOfType(JwtValidationException.class).isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT)); + } + @Test public void withJwkSetUriWhenNullOrEmptyThenThrowsException() { // @formatter:off diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index beb5966b9b..68584c28be 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -617,11 +617,31 @@ public class NimbusReactiveJwtDecoderTests { given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class))) .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"))); given(spec.retrieve()).willReturn(responseSpec); + NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer) + .webClient(webClient) + .build(); + jwtDecoder.setJwtValidator(JwtValidators.createDefault()); + Jwt jwt = jwtDecoder.decode(this.messageReadToken).block(); + assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); + } + + @Test + public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() { + String issuer = "https://example.org/wrong-issuer"; + WebClient real = WebClient.builder().build(); + WebClient.RequestHeadersUriSpec spec = spy(real.get()); + WebClient webClient = spy(WebClient.class); + given(webClient.get()).willReturn(spec); + WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class); + given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(this.jwkSet)); + given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class))) + .willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"))); + given(spec.retrieve()).willReturn(responseSpec); ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer) .webClient(webClient) .build(); - Jwt jwt = jwtDecoder.decode(this.messageReadToken).block(); - assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull(); + assertThatExceptionOfType(JwtValidationException.class) + .isThrownBy(() -> jwtDecoder.decode(this.messageReadToken).block()); } @Test