From acfe4bdcfba829b0929de703cd0b98010d775812 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 31 Jul 2020 08:59:32 -0600 Subject: [PATCH] Polish to Avoid NPE Issue gh-5648 Co-authored-by: MattyA --- .../oauth2/jwt/NimbusJwtDecoderJwkSupport.java | 15 +++++++++++++-- .../oauth2/jwt/NimbusReactiveJwtDecoder.java | 16 +++++++++++++--- .../jwt/NimbusJwtDecoderJwkSupportTests.java | 17 +++++++++++++++++ .../jwt/NimbusReactiveJwtDecoderTests.java | 17 +++++++++++++++++ 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java index 54d4754fea..471f9b2e74 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java @@ -21,6 +21,7 @@ import java.net.URL; import java.text.ParseException; import java.time.Instant; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; @@ -47,10 +48,12 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; @@ -190,9 +193,17 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { private Jwt validateJwt(Jwt jwt){ OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); if (result.hasErrors()) { - String description = result.getErrors().iterator().next().getDescription(); + Collection errors = result.getErrors(); + String validationErrorString = "Unable to validate Jwt"; + for (OAuth2Error oAuth2Error : errors) { + if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { + validationErrorString = String.format( + DECODING_ERROR_MESSAGE_TEMPLATE, oAuth2Error.getDescription()); + break; + } + } throw new JwtValidationException( - String.format(DECODING_ERROR_MESSAGE_TEMPLATE, description), + validationErrorString, result.getErrors()); } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 5376cf7ff3..aa4d99c9f8 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.jwt; import java.security.interfaces.RSAPublicKey; import java.time.Instant; +import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -40,10 +41,12 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; import reactor.core.publisher.Mono; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a @@ -184,9 +187,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private Jwt validateJwt(Jwt jwt) { OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); - if ( result.hasErrors() ) { - String message = result.getErrors().iterator().next().getDescription(); - throw new JwtValidationException(message, result.getErrors()); + if (result.hasErrors()) { + Collection errors = result.getErrors(); + String validationErrorString = "Unable to validate Jwt"; + for (OAuth2Error oAuth2Error : errors) { + if (!StringUtils.isEmpty(oAuth2Error.getDescription())) { + validationErrorString = oAuth2Error.getDescription(); + break; + } + } + throw new JwtValidationException(validationErrorString, errors); } return jwt; diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java index ca1e9d7561..8242e0ecca 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java @@ -33,6 +33,7 @@ import org.assertj.core.api.Assertions; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -241,6 +242,22 @@ public class NimbusJwtDecoderJwkSupportTests { } } + @Test + public void decodeWhenReadingErrorPickTheFirstErrorMessage() { + OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); + this.jwtDecoder.setJwtValidator(jwtValidator); + + OAuth2Error errorEmpty = new OAuth2Error("mock-error", "", "mock-uri"); + OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); + OAuth2Error error2 = new OAuth2Error("mock-error-second", "mock-description-second", "mock-uri-second"); + OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(errorEmpty, error, error2); + Mockito.when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + + Assertions.assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) + .isInstanceOf(JwtValidationException.class) + .hasMessageContaining("mock-description"); + } + @Test public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() throws Exception { try ( MockWebServer server = new MockWebServer() ) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index cc976bae1c..5cce82b7f1 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -177,6 +177,23 @@ public class NimbusReactiveJwtDecoderTests { .hasMessageContaining("mock-description"); } + + @Test + public void decodeWhenReadingErrorPickTheFirstErrorMessage() { + OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class); + this.decoder.setJwtValidator(jwtValidator); + + OAuth2Error errorEmpty = new OAuth2Error("mock-error", "", "mock-uri"); + OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri"); + OAuth2Error error2 = new OAuth2Error("mock-error-second", "mock-description-second", "mock-uri-second"); + OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(errorEmpty, error, error2); + when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); + + assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) + .isInstanceOf(JwtValidationException.class) + .hasMessageContaining("mock-description"); + } + @Test public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() { assertThatCode(() -> this.decoder.setJwtValidator(null))