Add BadJwtException

Updated NimbusJwtDecoder and NimbusReactiveJwtDecoder to throw.
Updated JwtAuthenticationProvider and JwtReactiveAuthenticationManager
to catch.

Fixes gh-7885
This commit is contained in:
Josh Cummings 2020-01-31 16:41:52 -07:00
parent fbdecdafb8
commit 0c3754c811
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
12 changed files with 157 additions and 47 deletions

View File

@ -94,6 +94,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
@ -256,7 +257,7 @@ public class OAuth2ResourceServerConfigurerTests {
this.mvc.perform(get("/").with(bearerToken(token))) this.mvc.perform(get("/").with(bearerToken(token)))
.andExpect(status().isUnauthorized()) .andExpect(status().isUnauthorized())
.andExpect(invalidTokenHeader("An error occurred while attempting to decode the Jwt: Malformed Jwk set")); .andExpect(header().string("WWW-Authenticate", "Bearer"));
} }
@Test @Test
@ -269,7 +270,7 @@ public class OAuth2ResourceServerConfigurerTests {
this.mvc.perform(get("/").with(bearerToken(token))) this.mvc.perform(get("/").with(bearerToken(token)))
.andExpect(status().isUnauthorized()) .andExpect(status().isUnauthorized())
.andExpect(invalidTokenHeader("Invalid token")); .andExpect(header().string("WWW-Authenticate", "Bearer"));
} }
@Test @Test
@ -1099,7 +1100,7 @@ public class OAuth2ResourceServerConfigurerTests {
this.spring.register(CustomAuthenticationEventPublisher.class).autowire(); this.spring.register(CustomAuthenticationEventPublisher.class).autowire();
when(bean(JwtDecoder.class).decode(anyString())) when(bean(JwtDecoder.class).decode(anyString()))
.thenThrow(new JwtException("problem")); .thenThrow(new BadJwtException("problem"));
this.mvc.perform(get("/").with(bearerToken("token"))); this.mvc.perform(get("/").with(bearerToken("token")));

View File

@ -0,0 +1,34 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
/**
* An exception similar to {@link org.springframework.security.authentication.BadCredentialsException}
* that indicates a {@link Jwt} that is invalid in some way.
*
* @author Josh Cummings
* @since 5.3
*/
public class BadJwtException extends JwtException {
public BadJwtException(String message) {
super(message);
}
public BadJwtException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,7 +29,7 @@ import org.springframework.util.Assert;
* @author Josh Cummings * @author Josh Cummings
* @since 5.1 * @since 5.1
*/ */
public class JwtValidationException extends JwtException { public class JwtValidationException extends BadJwtException {
private final Collection<OAuth2Error> errors; private final Collection<OAuth2Error> errors;
/** /**

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@ import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
@ -120,7 +121,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
public Jwt decode(String token) throws JwtException { public Jwt decode(String token) throws JwtException {
JWT jwt = parse(token); JWT jwt = parse(token);
if (jwt instanceof PlainJWT) { if (jwt instanceof PlainJWT) {
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
} }
Jwt createdJwt = createJwt(token, jwt); Jwt createdJwt = createJwt(token, jwt);
return validateJwt(createdJwt); return validateJwt(createdJwt);
@ -130,7 +131,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
try { try {
return JWTParser.parse(token); return JWTParser.parse(token);
} catch (Exception ex) { } catch (Exception ex) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} }
} }
@ -152,11 +153,13 @@ public final class NimbusJwtDecoder implements JwtDecoder {
} else { } else {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} }
} catch (JOSEException ex) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} catch (Exception ex) { } catch (Exception ex) {
if (ex.getCause() instanceof ParseException) { if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload")); throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload"));
} else { } else {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} }
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -136,7 +136,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
public Mono<Jwt> decode(String token) throws JwtException { public Mono<Jwt> decode(String token) throws JwtException {
JWT jwt = parse(token); JWT jwt = parse(token);
if (jwt instanceof PlainJWT) { if (jwt instanceof PlainJWT) {
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
} }
return this.decode(jwt); return this.decode(jwt);
} }
@ -145,7 +145,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
try { try {
return JWTParser.parse(token); return JWTParser.parse(token);
} catch (Exception ex) { } catch (Exception ex) {
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
} }
} }
@ -155,19 +155,25 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
.map(set -> createJwt(parsedToken, set)) .map(set -> createJwt(parsedToken, set))
.map(this::validateJwt) .map(this::validateJwt)
.onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
} catch (JwtException ex) {
throw ex;
} catch (RuntimeException ex) { } catch (RuntimeException ex) {
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
} }
} }
private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) { private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); try {
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
return Jwt.withTokenValue(parsedJwt.getParsedString()) return Jwt.withTokenValue(parsedJwt.getParsedString())
.headers(h -> h.putAll(headers)) .headers(h -> h.putAll(headers))
.claims(c -> c.putAll(claims)) .claims(c -> c.putAll(claims))
.build(); .build();
} catch (Exception ex) {
throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
}
} }
private Jwt validateJwt(Jwt jwt) { private Jwt validateJwt(Jwt jwt) {
@ -345,7 +351,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) { private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) {
if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) { if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) {
throw new JwtException("Unsupported algorithm of " + header.getAlgorithm()); throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm());
} }
return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header)); return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header));
@ -514,7 +520,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
.collectList() .collectList()
.map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks))); .map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks)));
} }
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
}; };
} }
} }
@ -524,7 +530,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
try { try {
return jwtProcessor.process(parsedToken, context); return jwtProcessor.process(parsedToken, context);
} }
catch (BadJOSEException | JOSEException e) { catch (BadJOSEException e) {
throw new BadJwtException("Failed to validate the token", e);
}
catch (JOSEException e) {
throw new JwtException("Failed to validate the token", e); throw new JwtException("Failed to validate the token", e);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -47,7 +47,7 @@ final class SingleKeyJWSKeySelector<C extends SecurityContext> implements JWSKey
@Override @Override
public List<? extends Key> selectJWSKeys(JWSHeader header, C context) { public List<? extends Key> selectJWSKeys(JWSHeader header, C context) {
if (!this.expectedJwsAlgorithm.equals(header.getAlgorithm())) { if (!this.expectedJwsAlgorithm.equals(header.getAlgorithm())) {
throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm()); throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm());
} }
return this.keySet; return this.keySet;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -132,7 +132,7 @@ public class NimbusJwtDecoderTests {
@Test @Test
public void decodeWhenJwtInvalidThenThrowJwtException() { public void decodeWhenJwtInvalidThenThrowJwtException() {
assertThatThrownBy(() -> this.jwtDecoder.decode("invalid")) assertThatThrownBy(() -> this.jwtDecoder.decode("invalid"))
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
// gh-5168 // gh-5168
@ -152,14 +152,14 @@ public class NimbusJwtDecoderTests {
@Test @Test
public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() { public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() {
assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(BadJwtException.class)
.hasMessageContaining("Unsupported algorithm of none"); .hasMessageContaining("Unsupported algorithm of none");
} }
@Test @Test
public void decodeWhenJwtIsMalformedThenReturnsStockException() { public void decodeWhenJwtIsMalformedThenReturnsStockException() {
assertThatCode(() -> this.jwtDecoder.decode(MALFORMED_JWT)) assertThatCode(() -> this.jwtDecoder.decode(MALFORMED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(BadJwtException.class)
.hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); .hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload");
} }
@ -205,6 +205,18 @@ public class NimbusJwtDecoderTests {
assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); assertThat(jwt.getClaims().get("custom")).isEqualTo("value");
} }
// gh-7885
@Test
public void decodeWhenClaimSetConverterFailsThenBadJwtException() {
Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
this.jwtDecoder.setClaimSetConverter(claimSetConverter);
when(claimSetConverter.convert(any(Map.class))).thenThrow(new IllegalArgumentException("bad conversion"));
assertThatCode(() -> this.jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(BadJwtException.class);
}
@Test @Test
public void decodeWhenSignedThenOk() { public void decodeWhenSignedThenOk() {
NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(JWK_SET)); NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(JWK_SET));
@ -217,6 +229,7 @@ public class NimbusJwtDecoderTests {
NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(MALFORMED_JWK_SET)); NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withSigning(MALFORMED_JWK_SET));
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(JwtException.class)
.isNotInstanceOf(BadJwtException.class)
.hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); .hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set");
} }
@ -229,6 +242,7 @@ public class NimbusJwtDecoderTests {
server.shutdown(); server.shutdown();
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(JwtException.class)
.isNotInstanceOf(BadJwtException.class)
.hasMessageContaining("An error occurred while attempting to decode the Jwt"); .hasMessageContaining("An error occurred while attempting to decode the Jwt");
} }
} }
@ -301,7 +315,7 @@ public class NimbusJwtDecoderTests {
public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception { public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception {
NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build();
Assertions.assertThatCode(() -> decoder.decode(RS256_SIGNED_JWT)) Assertions.assertThatCode(() -> decoder.decode(RS256_SIGNED_JWT))
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
@Test @Test
@ -345,7 +359,7 @@ public class NimbusJwtDecoderTests {
SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet); SignedJWT signedJWT = signedJwt(secretKey, macAlgorithm, claimsSet);
NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build(); NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build();
assertThatThrownBy(() -> decoder.decode(signedJWT.serialize())) assertThatThrownBy(() -> decoder.decode(signedJWT.serialize()))
.isInstanceOf(JwtException.class) .isInstanceOf(BadJwtException.class)
.hasMessageContaining("Unsupported algorithm of HS256"); .hasMessageContaining("Unsupported algorithm of HS256");
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -171,7 +171,7 @@ public class NimbusReactiveJwtDecoderTests {
@Test @Test
public void decodeWhenNoPeriodThenFail() { public void decodeWhenNoPeriodThenFail() {
assertThatCode(() -> this.decoder.decode("").block()) assertThatCode(() -> this.decoder.decode("").block())
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
@Test @Test
@ -184,26 +184,26 @@ public class NimbusReactiveJwtDecoderTests {
@Test @Test
public void decodeWhenInvalidSignatureThenFail() { public void decodeWhenInvalidSignatureThenFail() {
assertThatCode(() -> this.decoder.decode(this.messageReadToken.substring(0, this.messageReadToken.length() - 2)).block()) assertThatCode(() -> this.decoder.decode(this.messageReadToken.substring(0, this.messageReadToken.length() - 2)).block())
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
@Test @Test
public void decodeWhenAlgNoneThenFail() { public void decodeWhenAlgNoneThenFail() {
assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIm5vbmUiLA0KICAidHlwIjogIkpXVCINCn0.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block()) assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIm5vbmUiLA0KICAidHlwIjogIkpXVCINCn0.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block())
.isInstanceOf(JwtException.class) .isInstanceOf(BadJwtException.class)
.hasMessage("Unsupported algorithm of none"); .hasMessage("Unsupported algorithm of none");
} }
@Test @Test
public void decodeWhenInvalidAlgMismatchThenFail() { public void decodeWhenInvalidAlgMismatchThenFail() {
assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIkVTMjU2IiwNCiAgInR5cCI6ICJKV1QiDQp9.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block()) assertThatCode(() -> this.decoder.decode("ew0KICAiYWxnIjogIkVTMjU2IiwNCiAgInR5cCI6ICJKV1QiDQp9.ew0KICAic3ViIjogIjEyMzQ1Njc4OTAiLA0KICAibmFtZSI6ICJKb2huIERvZSIsDQogICJpYXQiOiAxNTE2MjM5MDIyDQp9.").block())
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
@Test @Test
public void decodeWhenUnsignedTokenThenMessageDoesNotMentionClass() { public void decodeWhenUnsignedTokenThenMessageDoesNotMentionClass() {
assertThatCode(() -> this.decoder.decode(this.unsignedToken).block()) assertThatCode(() -> this.decoder.decode(this.unsignedToken).block())
.isInstanceOf(JwtException.class) .isInstanceOf(BadJwtException.class)
.hasMessage("Unsupported algorithm of none"); .hasMessage("Unsupported algorithm of none");
} }
@ -217,7 +217,7 @@ public class NimbusReactiveJwtDecoderTests {
when(jwtValidator.validate(any(Jwt.class))).thenReturn(result); when(jwtValidator.validate(any(Jwt.class))).thenReturn(result);
assertThatCode(() -> this.decoder.decode(this.messageReadToken).block()) assertThatCode(() -> this.decoder.decode(this.messageReadToken).block())
.isInstanceOf(JwtException.class) .isInstanceOf(JwtValidationException.class)
.hasMessageContaining("mock-description"); .hasMessageContaining("mock-description");
} }
@ -234,6 +234,18 @@ public class NimbusReactiveJwtDecoderTests {
verify(claimSetConverter).convert(any(Map.class)); verify(claimSetConverter).convert(any(Map.class));
} }
// gh-7885
@Test
public void decodeWhenClaimSetConverterFailsThenBadJwtException() {
Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
this.decoder.setClaimSetConverter(claimSetConverter);
when(claimSetConverter.convert(any(Map.class))).thenThrow(new IllegalArgumentException("bad conversion"));
assertThatCode(() -> this.decoder.decode(this.messageReadToken).block())
.isInstanceOf(BadJwtException.class);
}
@Test @Test
public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() { public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
assertThatCode(() -> this.decoder.setJwtValidator(null)) assertThatCode(() -> this.decoder.setJwtValidator(null))
@ -310,7 +322,7 @@ public class NimbusReactiveJwtDecoderTests {
NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder decoder =
withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build(); withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build();
assertThatCode(() -> decoder.decode(this.rsa256).block()) assertThatCode(() -> decoder.decode(this.rsa256).block())
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
@Test @Test
@ -372,7 +384,7 @@ public class NimbusReactiveJwtDecoderTests {
this.decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build(); this.decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build();
assertThatThrownBy(() -> this.decoder.decode(signedJWT.serialize()).block()) assertThatThrownBy(() -> this.decoder.decode(signedJWT.serialize()).block())
.isInstanceOf(JwtException.class); .isInstanceOf(BadJwtException.class);
} }
@Test @Test

View File

@ -20,9 +20,11 @@ import java.util.Collection;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
@ -80,8 +82,10 @@ public final class JwtAuthenticationProvider implements AuthenticationProvider {
Jwt jwt; Jwt jwt;
try { try {
jwt = this.jwtDecoder.decode(bearer.getToken()); jwt = this.jwtDecoder.decode(bearer.getToken());
} catch (JwtException failed) { } catch (BadJwtException failed) {
throw new InvalidBearerTokenException(failed.getMessage(), failed); throw new InvalidBearerTokenException(failed.getMessage(), failed);
} catch (JwtException failed) {
throw new AuthenticationServiceException(failed.getMessage(), failed);
} }
AbstractAuthenticationToken token = this.jwtAuthenticationConverter.convert(jwt); AbstractAuthenticationToken token = this.jwtAuthenticationConverter.convert(jwt);

View File

@ -20,9 +20,11 @@ import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
@ -71,7 +73,11 @@ public final class JwtReactiveAuthenticationManager implements ReactiveAuthentic
this.jwtAuthenticationConverter = jwtAuthenticationConverter; this.jwtAuthenticationConverter = jwtAuthenticationConverter;
} }
private OAuth2AuthenticationException onError(JwtException e) { private AuthenticationException onError(JwtException e) {
return new InvalidBearerTokenException(e.getMessage(), e); if (e instanceof BadJwtException) {
return new InvalidBearerTokenException(e.getMessage(), e);
} else {
return new AuthenticationServiceException(e.getMessage(), e);
}
} }
} }

View File

@ -24,7 +24,9 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
@ -78,7 +80,7 @@ public class JwtAuthenticationProviderTests {
public void authenticateWhenJwtDecodeFailsThenRespondsWithInvalidToken() { public void authenticateWhenJwtDecodeFailsThenRespondsWithInvalidToken() {
BearerTokenAuthenticationToken token = this.authentication(); BearerTokenAuthenticationToken token = this.authentication();
when(this.jwtDecoder.decode("token")).thenThrow(JwtException.class); when(this.jwtDecoder.decode("token")).thenThrow(BadJwtException.class);
assertThatCode(() -> this.provider.authenticate(token)) assertThatCode(() -> this.provider.authenticate(token))
.matches(failed -> failed instanceof OAuth2AuthenticationException) .matches(failed -> failed instanceof OAuth2AuthenticationException)
@ -89,7 +91,7 @@ public class JwtAuthenticationProviderTests {
public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() { public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() {
BearerTokenAuthenticationToken token = this.authentication(); BearerTokenAuthenticationToken token = this.authentication();
when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("with \"invalid\" chars")); when(this.jwtDecoder.decode(token.getToken())).thenThrow(new BadJwtException("with \"invalid\" chars"));
assertThatCode(() -> this.provider.authenticate(token)) assertThatCode(() -> this.provider.authenticate(token))
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -98,6 +100,18 @@ public class JwtAuthenticationProviderTests {
"Invalid token"); "Invalid token");
} }
// gh-7785
@Test
public void authenticateWhenDecoderFailsGenericallyThenThrowsGenericException() {
BearerTokenAuthenticationToken token = this.authentication();
when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("no jwk set"));
assertThatCode(() -> this.provider.authenticate(token))
.isInstanceOf(AuthenticationException.class)
.isNotInstanceOf(OAuth2AuthenticationException.class);
}
@Test @Test
public void authenticateWhenConverterReturnsAuthenticationThenProviderPropagatesIt() { public void authenticateWhenConverterReturnsAuthenticationThenProviderPropagatesIt() {
BearerTokenAuthenticationToken token = this.authentication(); BearerTokenAuthenticationToken token = this.authentication();

View File

@ -25,8 +25,10 @@ import reactor.core.publisher.Mono;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
@ -82,7 +84,7 @@ public class JwtReactiveAuthenticationManagerTests {
@Test @Test
public void authenticateWhenJwtExceptionThenOAuth2AuthenticationException() { public void authenticateWhenJwtExceptionThenOAuth2AuthenticationException() {
BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1");
when(this.jwtDecoder.decode(any())).thenReturn(Mono.error(new JwtException("Oops"))); when(this.jwtDecoder.decode(any())).thenReturn(Mono.error(new BadJwtException("Oops")));
assertThatCode(() -> this.manager.authenticate(token).block()) assertThatCode(() -> this.manager.authenticate(token).block())
.isInstanceOf(OAuth2AuthenticationException.class); .isInstanceOf(OAuth2AuthenticationException.class);
@ -92,7 +94,7 @@ public class JwtReactiveAuthenticationManagerTests {
@Test @Test
public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() { public void authenticateWhenDecoderThrowsIncompatibleErrorMessageThenWrapsWithGenericOne() {
BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1");
when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("with \"invalid\" chars")); when(this.jwtDecoder.decode(token.getToken())).thenThrow(new BadJwtException("with \"invalid\" chars"));
assertThatCode(() -> this.manager.authenticate(token).block()) assertThatCode(() -> this.manager.authenticate(token).block())
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
@ -101,6 +103,17 @@ public class JwtReactiveAuthenticationManagerTests {
"Invalid token"); "Invalid token");
} }
// gh-7785
@Test
public void authenticateWhenDecoderFailsGenericallyThenThrowsGenericException() {
BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1");
when(this.jwtDecoder.decode(token.getToken())).thenThrow(new JwtException("no jwk set"));
assertThatCode(() -> this.manager.authenticate(token).block())
.isInstanceOf(AuthenticationException.class)
.isNotInstanceOf(OAuth2AuthenticationException.class);
}
@Test @Test
public void authenticateWhenNotJwtExceptionThenPropagates() { public void authenticateWhenNotJwtExceptionThenPropagates() {
BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1"); BearerTokenAuthenticationToken token = new BearerTokenAuthenticationToken("token-1");