Remove SignedJWT Check

JWTProcessor already does sufficient checking to confirm that the JWT
is of the appropriate type.

Fixes: gh-7034
This commit is contained in:
Josh Cummings 2019-06-25 16:49:29 -06:00
parent d2248d185b
commit 37d108ccc2
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 54 additions and 46 deletions

View File

@ -16,6 +16,17 @@
package org.springframework.security.oauth2.jwt; package org.springframework.security.oauth2.jwt;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.crypto.SecretKey;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
@ -32,10 +43,11 @@ import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser; import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
@ -51,17 +63,6 @@ import org.springframework.util.Assert;
import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import javax.crypto.SecretKey;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
/** /**
* A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus configuration. * A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus configuration.
* *
@ -119,11 +120,11 @@ public final class NimbusJwtDecoder implements JwtDecoder {
@Override @Override
public Jwt decode(String token) throws JwtException { public Jwt decode(String token) throws JwtException {
JWT jwt = parse(token); JWT jwt = parse(token);
if (jwt instanceof SignedJWT) { if (jwt instanceof PlainJWT) {
Jwt createdJwt = createJwt(token, jwt); throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
return validateJwt(createdJwt);
} }
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); Jwt createdJwt = createJwt(token, jwt);
return validateJwt(createdJwt);
} }
private JWT parse(String token) { private JWT parse(String token) {

View File

@ -15,6 +15,15 @@
*/ */
package org.springframework.security.oauth2.jwt; package org.springframework.security.oauth2.jwt;
import java.security.interfaces.RSAPublicKey;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
import javax.crypto.SecretKey;
import com.nimbusds.jose.Header;
import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSHeader;
@ -35,9 +44,13 @@ import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser; import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
@ -46,16 +59,6 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import javax.crypto.SecretKey;
import java.security.interfaces.RSAPublicKey;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
/** /**
* An implementation of a {@link ReactiveJwtDecoder} that "decodes" a * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a
@ -75,7 +78,7 @@ import java.util.function.Function;
* @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus JOSE + JWT SDK</a> * @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus JOSE + JWT SDK</a>
*/ */
public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
private final Converter<SignedJWT, Mono<JWTClaimsSet>> jwtProcessor; private final Converter<JWT, Mono<JWTClaimsSet>> jwtProcessor;
private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault(); private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter =
@ -106,7 +109,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
* @param jwtProcessor the {@link Converter} used to process and verify the signed Jwt and return the Jwt Claim Set * @param jwtProcessor the {@link Converter} used to process and verify the signed Jwt and return the Jwt Claim Set
* @since 5.2 * @since 5.2
*/ */
public NimbusReactiveJwtDecoder(Converter<SignedJWT, Mono<JWTClaimsSet>> jwtProcessor) { public NimbusReactiveJwtDecoder(Converter<JWT, Mono<JWTClaimsSet>> jwtProcessor) {
this.jwtProcessor = jwtProcessor; this.jwtProcessor = jwtProcessor;
} }
@ -133,10 +136,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
@Override @Override
public Mono<Jwt> decode(String token) throws JwtException { public Mono<Jwt> decode(String token) throws JwtException {
JWT jwt = parse(token); JWT jwt = parse(token);
if (jwt instanceof SignedJWT) { if (jwt instanceof PlainJWT) {
return this.decode((SignedJWT) jwt); throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
} }
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); return this.decode(jwt);
} }
private JWT parse(String token) { private JWT parse(String token) {
@ -147,7 +150,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
} }
} }
private Mono<Jwt> decode(SignedJWT parsedToken) { private Mono<Jwt> decode(JWT parsedToken) {
try { try {
return this.jwtProcessor.convert(parsedToken) return this.jwtProcessor.convert(parsedToken)
.map(set -> createJwt(parsedToken, set)) .map(set -> createJwt(parsedToken, set))
@ -280,7 +283,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return new NimbusReactiveJwtDecoder(processor()); return new NimbusReactiveJwtDecoder(processor());
} }
Converter<SignedJWT, Mono<JWTClaimsSet>> processor() { Converter<JWT, Mono<JWTClaimsSet>> processor() {
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
JWSKeySelector<JWKSecurityContext> jwsKeySelector = JWSKeySelector<JWKSecurityContext> jwsKeySelector =
@ -292,20 +295,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri); ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
source.setWebClient(this.webClient); source.setWebClient(this.webClient);
return signedJWT -> { return jwt -> {
JWKSelector selector = createSelector(signedJWT.getHeader()); JWKSelector selector = createSelector(jwt.getHeader());
return source.get(selector) return source.get(selector)
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
.map(jwkList -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwkList))); .map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
}; };
} }
private JWKSelector createSelector(JWSHeader header) { private JWKSelector createSelector(Header header) {
if (!this.jwsAlgorithm.equals(header.getAlgorithm())) { if (!this.jwsAlgorithm.equals(header.getAlgorithm())) {
throw new JwtException("Unsupported algorithm of " + header.getAlgorithm()); throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
} }
return new JWKSelector(JWKMatcher.forJWSHeader(header)); return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header));
} }
} }
@ -353,7 +356,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return new NimbusReactiveJwtDecoder(processor()); return new NimbusReactiveJwtDecoder(processor());
} }
Converter<SignedJWT, Mono<JWTClaimsSet>> processor() { Converter<JWT, Mono<JWTClaimsSet>> processor() {
if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) { if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) {
throw new IllegalStateException("The provided key is of type RSA; " + throw new IllegalStateException("The provided key is of type RSA; " +
"however the signature algorithm is of some other type: " + "however the signature algorithm is of some other type: " +
@ -370,7 +373,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });
return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null)); return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null));
} }
} }
@ -414,7 +417,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return new NimbusReactiveJwtDecoder(processor()); return new NimbusReactiveJwtDecoder(processor());
} }
Converter<SignedJWT, Mono<JWTClaimsSet>> processor() { Converter<JWT, Mono<JWTClaimsSet>> processor() {
JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey); JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey);
JWSKeySelector<SecurityContext> jwsKeySelector = JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
@ -424,7 +427,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { }); jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });
return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null)); return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null));
} }
} }
@ -464,7 +467,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
return new NimbusReactiveJwtDecoder(processor()); return new NimbusReactiveJwtDecoder(processor());
} }
Converter<SignedJWT, Mono<JWTClaimsSet>> processor() { Converter<JWT, Mono<JWTClaimsSet>> processor() {
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
JWSKeySelector<JWKSecurityContext> jwsKeySelector = JWSKeySelector<JWKSecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
@ -472,11 +475,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {}); jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
return signedJWT -> return jwt -> {
this.jwkSource.apply(signedJWT) if (jwt instanceof SignedJWT) {
return this.jwkSource.apply((SignedJWT) jwt)
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e)) .onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
.collectList() .collectList()
.map(jwks -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwks))); .map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks)));
}
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
};
} }
} }