diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 0ffcfe4bb8..a2b784eee4 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -16,6 +16,17 @@ package org.springframework.security.oauth2.jwt; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.security.interfaces.RSAPublicKey; +import java.text.ParseException; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import javax.crypto.SecretKey; + import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; @@ -32,10 +43,11 @@ import com.nimbusds.jose.util.ResourceRetriever; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; -import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -51,17 +63,6 @@ import org.springframework.util.Assert; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import javax.crypto.SecretKey; -import java.io.IOException; -import java.net.MalformedURLException; -import java.net.URL; -import java.security.interfaces.RSAPublicKey; -import java.text.ParseException; -import java.time.Instant; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; - /** * A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus configuration. * @@ -119,11 +120,11 @@ public final class NimbusJwtDecoder implements JwtDecoder { @Override public Jwt decode(String token) throws JwtException { JWT jwt = parse(token); - if (jwt instanceof SignedJWT) { - Jwt createdJwt = createJwt(token, jwt); - return validateJwt(createdJwt); + if (jwt instanceof PlainJWT) { + throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); } - throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); + Jwt createdJwt = createJwt(token, jwt); + return validateJwt(createdJwt); } private JWT parse(String token) { diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 03af89d1e0..0cf0fe02fe 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -15,6 +15,15 @@ */ package org.springframework.security.oauth2.jwt; +import java.security.interfaces.RSAPublicKey; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Function; +import javax.crypto.SecretKey; + +import com.nimbusds.jose.Header; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; @@ -35,9 +44,13 @@ import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; +import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; @@ -46,16 +59,6 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import javax.crypto.SecretKey; -import java.security.interfaces.RSAPublicKey; -import java.time.Instant; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.function.Function; /** * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a @@ -75,7 +78,7 @@ import java.util.function.Function; * @see Nimbus JOSE + JWT SDK */ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { - private final Converter> jwtProcessor; + private final Converter> jwtProcessor; private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); private Converter, Map> claimSetConverter = @@ -106,7 +109,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { * @param jwtProcessor the {@link Converter} used to process and verify the signed Jwt and return the Jwt Claim Set * @since 5.2 */ - public NimbusReactiveJwtDecoder(Converter> jwtProcessor) { + public NimbusReactiveJwtDecoder(Converter> jwtProcessor) { this.jwtProcessor = jwtProcessor; } @@ -133,10 +136,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { @Override public Mono decode(String token) throws JwtException { JWT jwt = parse(token); - if (jwt instanceof SignedJWT) { - return this.decode((SignedJWT) jwt); + if (jwt instanceof PlainJWT) { + throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); } - throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); + return this.decode(jwt); } private JWT parse(String token) { @@ -147,7 +150,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } } - private Mono decode(SignedJWT parsedToken) { + private Mono decode(JWT parsedToken) { try { return this.jwtProcessor.convert(parsedToken) .map(set -> createJwt(parsedToken, set)) @@ -280,7 +283,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new NimbusReactiveJwtDecoder(processor()); } - Converter> processor() { + Converter> processor() { JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); JWSKeySelector jwsKeySelector = @@ -292,20 +295,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri); source.setWebClient(this.webClient); - return signedJWT -> { - JWKSelector selector = createSelector(signedJWT.getHeader()); + return jwt -> { + JWKSelector selector = createSelector(jwt.getHeader()); return source.get(selector) .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) - .map(jwkList -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwkList))); + .map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList))); }; } - private JWKSelector createSelector(JWSHeader header) { + private JWKSelector createSelector(Header header) { if (!this.jwsAlgorithm.equals(header.getAlgorithm())) { throw new JwtException("Unsupported algorithm of " + header.getAlgorithm()); } - return new JWKSelector(JWKMatcher.forJWSHeader(header)); + return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header)); } } @@ -353,7 +356,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new NimbusReactiveJwtDecoder(processor()); } - Converter> processor() { + Converter> processor() { if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) { throw new IllegalStateException("The provided key is of type RSA; " + "however the signature algorithm is of some other type: " + @@ -370,7 +373,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { // Spring Security validates the claim set independent from Nimbus jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null)); + return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); } } @@ -414,7 +417,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new NimbusReactiveJwtDecoder(processor()); } - Converter> processor() { + Converter> processor() { JWKSource jwkSource = new ImmutableSecret<>(this.secretKey); JWSKeySelector jwsKeySelector = new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); @@ -424,7 +427,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { // Spring Security validates the claim set independent from Nimbus jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); - return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null)); + return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); } } @@ -464,7 +467,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { return new NimbusReactiveJwtDecoder(processor()); } - Converter> processor() { + Converter> processor() { JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); JWSKeySelector jwsKeySelector = new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); @@ -472,11 +475,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); - return signedJWT -> - this.jwkSource.apply(signedJWT) + return jwt -> { + if (jwt instanceof SignedJWT) { + return this.jwkSource.apply((SignedJWT) jwt) .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) .collectList() - .map(jwks -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwks))); + .map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks))); + } + throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); + }; } }