From 0c3754c81133afd33bfdac28569473d8ac2cdaff Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 31 Jan 2020 16:41:52 -0700 Subject: [PATCH] Add BadJwtException Updated NimbusJwtDecoder and NimbusReactiveJwtDecoder to throw. Updated JwtAuthenticationProvider and JwtReactiveAuthenticationManager to catch. Fixes gh-7885 --- .../OAuth2ResourceServerConfigurerTests.java | 7 ++-- .../security/oauth2/jwt/BadJwtException.java | 34 +++++++++++++++++++ .../oauth2/jwt/JwtValidationException.java | 4 +-- .../security/oauth2/jwt/NimbusJwtDecoder.java | 13 ++++--- .../oauth2/jwt/NimbusReactiveJwtDecoder.java | 33 +++++++++++------- .../oauth2/jwt/SingleKeyJWSKeySelector.java | 4 +-- .../oauth2/jwt/NimbusJwtDecoderTests.java | 26 ++++++++++---- .../jwt/NimbusReactiveJwtDecoderTests.java | 30 +++++++++++----- .../JwtAuthenticationProvider.java | 6 +++- .../JwtReactiveAuthenticationManager.java | 12 +++++-- .../JwtAuthenticationProviderTests.java | 18 ++++++++-- ...JwtReactiveAuthenticationManagerTests.java | 17 ++++++++-- 12 files changed, 157 insertions(+), 47 deletions(-) create mode 100644 oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java index e4427541f1..bd90cd2dcc 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java @@ -94,6 +94,7 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtException; @@ -256,7 +257,7 @@ public class OAuth2ResourceServerConfigurerTests { this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) - .andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt: Malformed Jwk set")); + .andExpect(header().string("WWW-Authenticate", "Bearer")); } @Test @@ -269,7 +270,7 @@ public class OAuth2ResourceServerConfigurerTests { this.mvc.perform(get("/").with(bearerToken(token))) .andExpect(status().isUnauthorized()) - .andExpect(invalidTokenHeader("Invalid token")); + .andExpect(header().string("WWW-Authenticate", "Bearer")); } @Test @@ -1099,7 +1100,7 @@ public class OAuth2ResourceServerConfigurerTests { this.spring.register(CustomAuthenticationEventPublisher.class).autowire(); when(bean(JwtDecoder.class).decode(anyString())) - .thenThrow(new JwtException("problem")); + .thenThrow(new BadJwtException("problem")); this.mvc.perform(get("/").with(bearerToken("token"))); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java new file mode 100644 index 0000000000..11aa05b9f9 --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/BadJwtException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +/** + * An exception similar to {@link org.springframework.security.authentication.BadCredentialsException} + * that indicates a {@link Jwt} that is invalid in some way. + * + * @author Josh Cummings + * @since 5.3 + */ +public class BadJwtException extends JwtException { + public BadJwtException(String message) { + super(message); + } + + public BadJwtException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java index 7c300941b7..3ea9e11050 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidationException.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ import org.springframework.util.Assert; * @author Josh Cummings * @since 5.1 */ -public class JwtValidationException extends JwtException { +public class JwtValidationException extends BadJwtException { private final Collection errors; /** diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 8478a58bf6..68c2318a1b 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import java.util.Set; import java.util.function.Consumer; import javax.crypto.SecretKey; +import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.source.JWKSource; @@ -120,7 +121,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { public Jwt decode(String token) throws JwtException { JWT jwt = parse(token); if (jwt instanceof PlainJWT) { - throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); + throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); } Jwt createdJwt = createJwt(token, jwt); return validateJwt(createdJwt); @@ -130,7 +131,7 @@ public final class NimbusJwtDecoder implements JwtDecoder { try { return JWTParser.parse(token); } catch (Exception ex) { - throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } } @@ -152,11 +153,13 @@ public final class NimbusJwtDecoder implements JwtDecoder { } else { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } + } catch (JOSEException ex) { + throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } catch (Exception ex) { if (ex.getCause() instanceof ParseException) { - throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload")); + throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload")); } else { - throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); } } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 1779d49f50..c80bbb4a3a 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -136,7 +136,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { public Mono decode(String token) throws JwtException { JWT jwt = parse(token); if (jwt instanceof PlainJWT) { - throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); + throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); } return this.decode(jwt); } @@ -145,7 +145,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { try { return JWTParser.parse(token); } catch (Exception ex) { - throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); + throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); } } @@ -155,19 +155,25 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { .map(set -> createJwt(parsedToken, set)) .map(this::validateJwt) .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); + } catch (JwtException ex) { + throw ex; } catch (RuntimeException ex) { throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); } } private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) { - Map headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); - Map claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); + try { + Map headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); + Map claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); - return Jwt.withTokenValue(parsedJwt.getParsedString()) - .headers(h -> h.putAll(headers)) - .claims(c -> c.putAll(claims)) - .build(); + return Jwt.withTokenValue(parsedJwt.getParsedString()) + .headers(h -> h.putAll(headers)) + .claims(c -> c.putAll(claims)) + .build(); + } catch (Exception ex) { + throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); + } } private Jwt validateJwt(Jwt jwt) { @@ -345,7 +351,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private JWKSelector createSelector(Set expectedJwsAlgorithms, Header header) { if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) { - throw new JwtException("Unsupported algorithm of " + header.getAlgorithm()); + throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm()); } return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header)); @@ -514,7 +520,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { .collectList() .map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks))); } - throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); + throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); }; } } @@ -524,7 +530,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { try { return jwtProcessor.process(parsedToken, context); } - catch (BadJOSEException | JOSEException e) { + catch (BadJOSEException e) { + throw new BadJwtException("Failed to validate the token", e); + } + catch (JOSEException e) { throw new JwtException("Failed to validate the token", e); } } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java index 97d5d3a663..4111a2866e 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ final class SingleKeyJWSKeySelector implements JWSKey @Override public List selectJWSKeys(JWSHeader header, C context) { if (!this.expectedJwsAlgorithm.equals(header.getAlgorithm())) { - throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm()); + throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm()); } return this.keySet; } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index c6fec7d112..27ab00ff35 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -132,7 +132,7 @@ public class NimbusJwtDecoderTests { @Test public void decodeWhenJwtInvalidThenThrowJwtException() { assertThatThrownBy(() -> this.jwtDecoder.decode("invalid")) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } // gh-5168 @@ -152,14 +152,14 @@ public class NimbusJwtDecoderTests { @Test public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() { assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) - .isInstanceOf(JwtException.class) + .isInstanceOf(BadJwtException.class) .hasMessageContaining("Unsupported algorithm of none"); } @Test public void decodeWhenJwtIsMalformedThenReturnsStockException() { assertThatCode(() -> this.jwtDecoder.decode(MALFORMED_JWT)) - .isInstanceOf(JwtException.class) + .isInstanceOf(BadJwtException.class) .hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); } @@ -205,6 +205,18 @@ public class NimbusJwtDecoderTests { assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); } + // gh-7885 + @Test + public void decodeWhenClaimSetConverterFailsThenBadJwtException() { + Converter, Map> claimSetConverter = mock(Converter.class); + this.jwtDecoder.setClaimSetConverter(claimSetConverter); + + when(claimSetConverter.convert(any(Map.class))).thenThrow(new IllegalArgumentException("bad conversion")); + + assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT)) + .isInstanceOf(BadJwtException.class); + } + @Test public void decodeWhenSignedThenOk() { NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(JWK_SET)); @@ -217,6 +229,7 @@ public class NimbusJwtDecoderTests { NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(MALFORMED_JWK_SET)); assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) .isInstanceOf(JwtException.class) + .isNotInstanceOf(BadJwtException.class) .hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); } @@ -229,6 +242,7 @@ public class NimbusJwtDecoderTests { server.shutdown(); assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) .isInstanceOf(JwtException.class) + .isNotInstanceOf(BadJwtException.class) .hasMessageContaining("An error occurred while attempting to decode the Jwt"); } } @@ -301,7 +315,7 @@ public class NimbusJwtDecoderTests { public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception { NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); Assertions.assertThatCode(() -> decoder.decode(RS256_SIGNED_JWT)) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } @Test @@ -345,7 +359,7 @@ public class NimbusJwtDecoderTests { SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet); NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build(); assertThatThrownBy(() -> decoder.decode(signedJWT.serialize())) - .isInstanceOf(JwtException.class) + .isInstanceOf(BadJwtException.class) .hasMessageContaining("Unsupported algorithm of HS256"); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index 59a361141a..74a9d8f671 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -171,7 +171,7 @@ public class NimbusReactiveJwtDecoderTests { @Test public void decodeWhenNoPeriodThenFail() { assertThatCode(() -> this.decoder.decode("").block()) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } @Test @@ -184,26 +184,26 @@ public class NimbusReactiveJwtDecoderTests { @Test public void decodeWhenInvalidSignatureThenFail() { assertThatCode(() -> this.decoder.decode(this.messageReadToken.substring(0, this.messageReadToken.length() - 2)).block()) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } @Test public void decodeWhenAlgNoneThenFail() { assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIm5vbmUiLA0KICAidHlwIjogIkpXVCINCn0.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block()) - .isInstanceOf(JwtException.class) + .isInstanceOf(BadJwtException.class) .hasMessage("Unsupported algorithm of none"); } @Test public void decodeWhenInvalidAlgMismatchThenFail() { assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIkVTMjU2IiwNCiAgInR5cCI6ICJKV1QiDQp9.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block()) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } @Test public void decodeWhenUnsignedTokenThenMessageDoesNotMentionClass() { assertThatCode(() -> this.decoder.decode(this.unsignedToken).block()) - .isInstanceOf(JwtException.class) + .isInstanceOf(BadJwtException.class) .hasMessage("Unsupported algorithm of none"); } @@ -217,7 +217,7 @@ public class NimbusReactiveJwtDecoderTests { when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) - .isInstanceOf(JwtException.class) + .isInstanceOf(JwtValidationException.class) .hasMessageContaining("mock-description"); } @@ -234,6 +234,18 @@ public class NimbusReactiveJwtDecoderTests { verify(claimSetConverter).convert(any(Map.class)); } + // gh-7885 + @Test + public void decodeWhenClaimSetConverterFailsThenBadJwtException() { + Converter, Map> claimSetConverter = mock(Converter.class); + this.decoder.setClaimSetConverter(claimSetConverter); + + when(claimSetConverter.convert(any(Map.class))).thenThrow(new IllegalArgumentException("bad conversion")); + + assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) + .isInstanceOf(BadJwtException.class); + } + @Test public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() { assertThatCode(() -> this.decoder.setJwtValidator(null)) @@ -310,7 +322,7 @@ public class NimbusReactiveJwtDecoderTests { NimbusReactiveJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); assertThatCode(() -> decoder.decode(this.rsa256).block()) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } @Test @@ -372,7 +384,7 @@ public class NimbusReactiveJwtDecoderTests { this.decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build(); assertThatThrownBy(() -> this.decoder.decode(signedJWT.serialize()).block()) - .isInstanceOf(JwtException.class); + .isInstanceOf(BadJwtException.class); } @Test diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProvider.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProvider.java index 8eb80a947d..54b3cc6a67 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProvider.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProvider.java @@ -20,9 +20,11 @@ import java.util.Collection; import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtException; @@ -80,8 +82,10 @@ public final class JwtAuthenticationProvider implements AuthenticationProvider { Jwt jwt; try { jwt = this.jwtDecoder.decode(bearer.getToken()); - } catch (JwtException failed) { + } catch (BadJwtException failed) { throw new InvalidBearerTokenException(failed.getMessage(), failed); + } catch (JwtException failed) { + throw new AuthenticationServiceException(failed.getMessage(), failed); } AbstractAuthenticationToken token = this.jwtAuthenticationConverter.convert(jwt); diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java index c475ef37c0..ab877b2181 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManager.java @@ -20,9 +20,11 @@ import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; @@ -71,7 +73,11 @@ public final class JwtReactiveAuthenticationManager implements ReactiveAuthentic this.jwtAuthenticationConverter = jwtAuthenticationConverter; } - private OAuth2AuthenticationException onError(JwtException e) { - return new InvalidBearerTokenException(e.getMessage(), e); + private AuthenticationException onError(JwtException e) { + if (e instanceof BadJwtException) { + return new InvalidBearerTokenException(e.getMessage(), e); + } else { + return new AuthenticationServiceException(e.getMessage(), e); + } } } diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java index 936ebb0bbf..50db5d9187 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtAuthenticationProviderTests.java @@ -24,7 +24,9 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtException; @@ -78,7 +80,7 @@ public class JwtAuthenticationProviderTests { public void authenticateWhenJwtDecodeFailsThenRespondsWithInvalidToken() { BearerTokenAuthenticationToken token = this.authentication(); - when(this.jwtDecoder.decode("token")).thenThrow(JwtException.class); + when(this.jwtDecoder.decode("token")).thenThrow(BadJwtException.class); assertThatCode(() -> this.provider.authenticate(token)) .matches(failed -> failed instanceof OAuth2AuthenticationException) @@ -89,7 +91,7 @@ public class JwtAuthenticationProviderTests { public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() { BearerTokenAuthenticationToken token = this.authentication(); - when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("with \"invalid\" chars")); + when(this.jwtDecoder.decode(token.getToken())).thenThrow(new BadJwtException("with \"invalid\" chars")); assertThatCode(() -> this.provider.authenticate(token)) .isInstanceOf(OAuth2AuthenticationException.class) @@ -98,6 +100,18 @@ public class JwtAuthenticationProviderTests { "Invalid token"); } + // gh-7785 + @Test + public void authenticateWhenDecoderFailsGenericallyThenThrowsGenericException() { + BearerTokenAuthenticationToken token = this.authentication(); + + when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("no jwk set")); + + assertThatCode(() -> this.provider.authenticate(token)) + .isInstanceOf(AuthenticationException.class) + .isNotInstanceOf(OAuth2AuthenticationException.class); + } + @Test public void authenticateWhenConverterReturnsAuthenticationThenProviderPropagatesIt() { BearerTokenAuthenticationToken token = this.authentication(); diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java index e103dd4851..8317aaba58 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/authentication/JwtReactiveAuthenticationManagerTests.java @@ -25,8 +25,10 @@ import reactor.core.publisher.Mono; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; @@ -82,7 +84,7 @@ public class JwtReactiveAuthenticationManagerTests { @Test public void authenticateWhenJwtExceptionThenOAuth2AuthenticationException() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(any())).thenReturn(Mono.error(new JwtException("Oops"))); + when(this.jwtDecoder.decode(any())).thenReturn(Mono.error(new BadJwtException("Oops"))); assertThatCode(() -> this.manager.authenticate(token).block()) .isInstanceOf(OAuth2AuthenticationException.class); @@ -92,7 +94,7 @@ public class JwtReactiveAuthenticationManagerTests { @Test public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); - when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("with \"invalid\" chars")); + when(this.jwtDecoder.decode(token.getToken())).thenThrow(new BadJwtException("with \"invalid\" chars")); assertThatCode(() -> this.manager.authenticate(token).block()) .isInstanceOf(OAuth2AuthenticationException.class) @@ -101,6 +103,17 @@ public class JwtReactiveAuthenticationManagerTests { "Invalid token"); } + // gh-7785 + @Test + public void authenticateWhenDecoderFailsGenericallyThenThrowsGenericException() { + BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); + when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("no jwk set")); + + assertThatCode(() -> this.manager.authenticate(token).block()) + .isInstanceOf(AuthenticationException.class) + .isNotInstanceOf(OAuth2AuthenticationException.class); + } + @Test public void authenticateWhenNotJwtExceptionThenPropagates() { BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1");