Polish spring-security-oauth2-jose main code

Manually polish `spring-security-oauth-jose` following the
formatting and checkstyle fixes.

Issue gh-8945
This commit is contained in:
Phillip Webb 2020-07-31 22:16:50 -07:00 committed by Rob Winch
parent a577871bca
commit 20aa8bef25
15 changed files with 96 additions and 175 deletions

View File

@ -55,6 +55,15 @@ public enum MacAlgorithm implements JwsAlgorithm {
this.name = name; this.name = name;
} }
/**
* Returns the algorithm name.
* @return the algorithm name
*/
@Override
public String getName() {
return this.name;
}
/** /**
* Attempt to resolve the provided algorithm name to a {@code MacAlgorithm}. * Attempt to resolve the provided algorithm name to a {@code MacAlgorithm}.
* @param name the algorithm name * @param name the algorithm name
@ -69,13 +78,4 @@ public enum MacAlgorithm implements JwsAlgorithm {
return null; return null;
} }
/**
* Returns the algorithm name.
* @return the algorithm name
*/
@Override
public String getName() {
return this.name;
}
} }

View File

@ -85,6 +85,15 @@ public enum SignatureAlgorithm implements JwsAlgorithm {
this.name = name; this.name = name;
} }
/**
* Returns the algorithm name.
* @return the algorithm name
*/
@Override
public String getName() {
return this.name;
}
/** /**
* Attempt to resolve the provided algorithm name to a {@code SignatureAlgorithm}. * Attempt to resolve the provided algorithm name to a {@code SignatureAlgorithm}.
* @param name the algorithm name * @param name the algorithm name
@ -99,13 +108,4 @@ public enum SignatureAlgorithm implements JwsAlgorithm {
return null; return null;
} }
/**
* Returns the algorithm name.
* @return the algorithm name
*/
@Override
public String getName() {
return this.name;
}
} }

View File

@ -58,9 +58,6 @@ public final class JwtClaimValidator<T> implements OAuth2TokenValidator<Jwt> {
"https://tools.ietf.org/html/rfc6750#section-3.1"); "https://tools.ietf.org/html/rfc6750#section-3.1");
} }
/**
* {@inheritDoc}
*/
@Override @Override
public OAuth2TokenValidatorResult validate(Jwt token) { public OAuth2TokenValidatorResult validate(Jwt token) {
Assert.notNull(token, "token cannot be null"); Assert.notNull(token, "token cannot be null");
@ -68,10 +65,8 @@ public final class JwtClaimValidator<T> implements OAuth2TokenValidator<Jwt> {
if (this.test.test(claimValue)) { if (this.test.test(claimValue)) {
return OAuth2TokenValidatorResult.success(); return OAuth2TokenValidatorResult.success();
} }
else {
this.logger.debug(this.error.getDescription()); this.logger.debug(this.error.getDescription());
return OAuth2TokenValidatorResult.failure(this.error); return OAuth2TokenValidatorResult.failure(this.error);
} }
}
} }

View File

@ -23,6 +23,7 @@ import java.util.Map;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.RequestEntity; import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
@ -46,7 +47,7 @@ final class JwtDecoderProviderConfigurationUtils {
private static final RestTemplate rest = new RestTemplate(); private static final RestTemplate rest = new RestTemplate();
private static final ParameterizedTypeReference<Map<String, Object>> typeReference = new ParameterizedTypeReference<Map<String, Object>>() { private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<Map<String, Object>>() {
}; };
private JwtDecoderProviderConfigurationUtils() { private JwtDecoderProviderConfigurationUtils() {
@ -62,14 +63,16 @@ final class JwtDecoderProviderConfigurationUtils {
} }
static void validateIssuer(Map<String, Object> configuration, String issuer) { static void validateIssuer(Map<String, Object> configuration, String issuer) {
String metadataIssuer = "(unavailable)"; String metadataIssuer = getMetadataIssuer(configuration);
if (configuration.containsKey("issuer")) { Assert.state(issuer.equals(metadataIssuer), () -> "The Issuer \"" + metadataIssuer
metadataIssuer = configuration.get("issuer").toString();
}
if (!issuer.equals(metadataIssuer)) {
throw new IllegalStateException("The Issuer \"" + metadataIssuer
+ "\" provided in the configuration did not " + "match the requested issuer \"" + issuer + "\""); + "\" provided in the configuration did not " + "match the requested issuer \"" + issuer + "\"");
} }
private static String getMetadataIssuer(Map<String, Object> configuration) {
if (configuration.containsKey("issuer")) {
return configuration.get("issuer").toString();
}
return "(unavailable)";
} }
private static Map<String, Object> getConfiguration(String issuer, URI... uris) { private static Map<String, Object> getConfiguration(String issuer, URI... uris) {
@ -77,13 +80,9 @@ final class JwtDecoderProviderConfigurationUtils {
for (URI uri : uris) { for (URI uri : uris) {
try { try {
RequestEntity<Void> request = RequestEntity.get(uri).build(); RequestEntity<Void> request = RequestEntity.get(uri).build();
ResponseEntity<Map<String, Object>> response = rest.exchange(request, typeReference); ResponseEntity<Map<String, Object>> response = rest.exchange(request, STRING_OBJECT_MAP);
Map<String, Object> configuration = response.getBody(); Map<String, Object> configuration = response.getBody();
Assert.isTrue(configuration.get("jwks_uri") != null, "The public JWK set URI must not be null");
if (configuration.get("jwks_uri") == null) {
throw new IllegalArgumentException("The public JWK set URI must not be null");
}
return configuration; return configuration;
} }
catch (IllegalArgumentException ex) { catch (IllegalArgumentException ex) {

View File

@ -34,6 +34,9 @@ import org.springframework.util.Assert;
*/ */
public final class JwtDecoders { public final class JwtDecoders {
private JwtDecoders() {
}
/** /**
* Creates a {@link JwtDecoder} using the provided <a href= * Creates a {@link JwtDecoder} using the provided <a href=
* "https://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier">Issuer</a> * "https://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier">Issuer</a>
@ -105,11 +108,7 @@ public final class JwtDecoders {
OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer); OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(configuration.get("jwks_uri").toString()).build(); NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(configuration.get("jwks_uri").toString()).build();
jwtDecoder.setJwtValidator(jwtValidator); jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder; return jwtDecoder;
} }
private JwtDecoders() {
}
} }

View File

@ -39,9 +39,6 @@ public final class JwtIssuerValidator implements OAuth2TokenValidator<Jwt> {
this.validator = new JwtClaimValidator(JwtClaimNames.ISS, issuer::equals); this.validator = new JwtClaimValidator(JwtClaimNames.ISS, issuer::equals);
} }
/**
* {@inheritDoc}
*/
@Override @Override
public OAuth2TokenValidatorResult validate(Jwt token) { public OAuth2TokenValidatorResult validate(Jwt token) {
Assert.notNull(token, "token cannot be null"); Assert.notNull(token, "token cannot be null");

View File

@ -68,31 +68,23 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator<Jwt> {
this.clockSkew = clockSkew; this.clockSkew = clockSkew;
} }
/**
* {@inheritDoc}
*/
@Override @Override
public OAuth2TokenValidatorResult validate(Jwt jwt) { public OAuth2TokenValidatorResult validate(Jwt jwt) {
Assert.notNull(jwt, "jwt cannot be null"); Assert.notNull(jwt, "jwt cannot be null");
Instant expiry = jwt.getExpiresAt(); Instant expiry = jwt.getExpiresAt();
if (expiry != null) { if (expiry != null) {
if (Instant.now(this.clock).minus(this.clockSkew).isAfter(expiry)) { if (Instant.now(this.clock).minus(this.clockSkew).isAfter(expiry)) {
OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt expired at %s", jwt.getExpiresAt())); OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt expired at %s", jwt.getExpiresAt()));
return OAuth2TokenValidatorResult.failure(oAuth2Error); return OAuth2TokenValidatorResult.failure(oAuth2Error);
} }
} }
Instant notBefore = jwt.getNotBefore(); Instant notBefore = jwt.getNotBefore();
if (notBefore != null) { if (notBefore != null) {
if (Instant.now(this.clock).plus(this.clockSkew).isBefore(notBefore)) { if (Instant.now(this.clock).plus(this.clockSkew).isBefore(notBefore)) {
OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt used before %s", jwt.getNotBefore())); OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt used before %s", jwt.getNotBefore()));
return OAuth2TokenValidatorResult.failure(oAuth2Error); return OAuth2TokenValidatorResult.failure(oAuth2Error);
} }
} }
return OAuth2TokenValidatorResult.success(); return OAuth2TokenValidatorResult.success();
} }
@ -103,8 +95,7 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator<Jwt> {
} }
/** /**
* ' Use this {@link Clock} with {@link Instant#now()} for assessing timestamp * Use this {@link Clock} with {@link Instant#now()} for assessing timestamp validity
* validity
* @param clock * @param clock
*/ */
public void setClock(Clock clock) { public void setClock(Clock clock) {

View File

@ -54,7 +54,6 @@ public class JwtValidationException extends BadJwtException {
*/ */
public JwtValidationException(String message, Collection<OAuth2Error> errors) { public JwtValidationException(String message, Collection<OAuth2Error> errors) {
super(message); super(message);
Assert.notEmpty(errors, "errors cannot be empty"); Assert.notEmpty(errors, "errors cannot be empty");
this.errors = new ArrayList<>(errors); this.errors = new ArrayList<>(errors);
} }

View File

@ -32,6 +32,9 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator;
*/ */
public final class JwtValidators { public final class JwtValidators {
private JwtValidators() {
}
/** /**
* <p> * <p>
* Create a {@link Jwt} Validator that contains all standard validators when an issuer * Create a {@link Jwt} Validator that contains all standard validators when an issuer
@ -69,7 +72,4 @@ public final class JwtValidators {
return new DelegatingOAuth2TokenValidator<>(Arrays.asList(new JwtTimestampValidator())); return new DelegatingOAuth2TokenValidator<>(Arrays.asList(new JwtTimestampValidator()));
} }
private JwtValidators() {
}
} }

View File

@ -93,11 +93,9 @@ public final class MappedJwtClaimSetConverter implements Converter<Map<String, O
*/ */
public static MappedJwtClaimSetConverter withDefaults(Map<String, Converter<Object, ?>> claimTypeConverters) { public static MappedJwtClaimSetConverter withDefaults(Map<String, Converter<Object, ?>> claimTypeConverters) {
Assert.notNull(claimTypeConverters, "claimTypeConverters cannot be null"); Assert.notNull(claimTypeConverters, "claimTypeConverters cannot be null");
Converter<Object, ?> stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); Converter<Object, ?> stringConverter = getConverter(STRING_TYPE_DESCRIPTOR);
Converter<Object, ?> collectionStringConverter = getConverter( Converter<Object, ?> collectionStringConverter = getConverter(
TypeDescriptor.collection(Collection.class, STRING_TYPE_DESCRIPTOR)); TypeDescriptor.collection(Collection.class, STRING_TYPE_DESCRIPTOR));
Map<String, Converter<Object, ?>> claimNameToConverter = new HashMap<>(); Map<String, Converter<Object, ?>> claimNameToConverter = new HashMap<>();
claimNameToConverter.put(JwtClaimNames.AUD, collectionStringConverter); claimNameToConverter.put(JwtClaimNames.AUD, collectionStringConverter);
claimNameToConverter.put(JwtClaimNames.EXP, MappedJwtClaimSetConverter::convertInstant); claimNameToConverter.put(JwtClaimNames.EXP, MappedJwtClaimSetConverter::convertInstant);
@ -107,7 +105,6 @@ public final class MappedJwtClaimSetConverter implements Converter<Map<String, O
claimNameToConverter.put(JwtClaimNames.NBF, MappedJwtClaimSetConverter::convertInstant); claimNameToConverter.put(JwtClaimNames.NBF, MappedJwtClaimSetConverter::convertInstant);
claimNameToConverter.put(JwtClaimNames.SUB, stringConverter); claimNameToConverter.put(JwtClaimNames.SUB, stringConverter);
claimNameToConverter.putAll(claimTypeConverters); claimNameToConverter.putAll(claimTypeConverters);
return new MappedJwtClaimSetConverter(claimNameToConverter); return new MappedJwtClaimSetConverter(claimNameToConverter);
} }
@ -120,9 +117,7 @@ public final class MappedJwtClaimSetConverter implements Converter<Map<String, O
return null; return null;
} }
Instant result = (Instant) CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, INSTANT_TYPE_DESCRIPTOR); Instant result = (Instant) CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, INSTANT_TYPE_DESCRIPTOR);
if (result == null) { Assert.state(result != null, () -> "Could not coerce " + source + " into an Instant");
throw new IllegalStateException("Could not coerce " + source + " into an Instant");
}
return result; return result;
} }
@ -145,24 +140,17 @@ public final class MappedJwtClaimSetConverter implements Converter<Map<String, O
return (String) CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, STRING_TYPE_DESCRIPTOR); return (String) CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, STRING_TYPE_DESCRIPTOR);
} }
/**
* {@inheritDoc}
*/
@Override @Override
public Map<String, Object> convert(Map<String, Object> claims) { public Map<String, Object> convert(Map<String, Object> claims) {
Assert.notNull(claims, "claims cannot be null"); Assert.notNull(claims, "claims cannot be null");
Map<String, Object> mappedClaims = this.delegate.convert(claims); Map<String, Object> mappedClaims = this.delegate.convert(claims);
mappedClaims = removeClaims(mappedClaims); mappedClaims = removeClaims(mappedClaims);
mappedClaims = addClaims(mappedClaims); mappedClaims = addClaims(mappedClaims);
Instant issuedAt = (Instant) mappedClaims.get(JwtClaimNames.IAT); Instant issuedAt = (Instant) mappedClaims.get(JwtClaimNames.IAT);
Instant expiresAt = (Instant) mappedClaims.get(JwtClaimNames.EXP); Instant expiresAt = (Instant) mappedClaims.get(JwtClaimNames.EXP);
if (issuedAt == null && expiresAt != null) { if (issuedAt == null && expiresAt != null) {
mappedClaims.put(JwtClaimNames.IAT, expiresAt.minusSeconds(1)); mappedClaims.put(JwtClaimNames.IAT, expiresAt.minusSeconds(1));
} }
return mappedClaims; return mappedClaims;
} }

View File

@ -145,20 +145,16 @@ public final class NimbusJwtDecoder implements JwtDecoder {
try { try {
// Verify the signature // Verify the signature
JWTClaimsSet jwtClaimsSet = this.jwtProcessor.process(parsedJwt, null); JWTClaimsSet jwtClaimsSet = this.jwtProcessor.process(parsedJwt, null);
Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
return Jwt.withTokenValue(token).headers((h) -> h.putAll(headers)).claims((c) -> c.putAll(claims)).build(); return Jwt.withTokenValue(token).headers((h) -> h.putAll(headers)).claims((c) -> c.putAll(claims)).build();
} }
catch (RemoteKeySourceException ex) { catch (RemoteKeySourceException ex) {
if (ex.getCause() instanceof ParseException) { if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set")); throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"));
} }
else {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} }
}
catch (JOSEException ex) { catch (JOSEException ex) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} }
@ -166,28 +162,27 @@ public final class NimbusJwtDecoder implements JwtDecoder {
if (ex.getCause() instanceof ParseException) { if (ex.getCause() instanceof ParseException) {
throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload")); throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed payload"));
} }
else {
throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); throw new BadJwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex);
} }
} }
}
private Jwt validateJwt(Jwt jwt) { private Jwt validateJwt(Jwt jwt) {
OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
if (result.hasErrors()) { if (result.hasErrors()) {
Collection<OAuth2Error> errors = result.getErrors(); Collection<OAuth2Error> errors = result.getErrors();
String validationErrorString = "Unable to validate Jwt"; String validationErrorString = getJwtValidationExceptionMessage(errors);
for (OAuth2Error oAuth2Error : errors) { throw new JwtValidationException(validationErrorString, errors);
if (!StringUtils.isEmpty(oAuth2Error.getDescription())) {
validationErrorString = String.format(DECODING_ERROR_MESSAGE_TEMPLATE,
oAuth2Error.getDescription());
break;
} }
} return jwt;
throw new JwtValidationException(validationErrorString, result.getErrors());
} }
return jwt; private String getJwtValidationExceptionMessage(Collection<OAuth2Error> errors) {
for (OAuth2Error oAuth2Error : errors) {
if (!StringUtils.isEmpty(oAuth2Error.getDescription())) {
return String.format(DECODING_ERROR_MESSAGE_TEMPLATE, oAuth2Error.getDescription());
}
}
return "Unable to validate Jwt";
} }
/** /**
@ -316,7 +311,6 @@ public final class NimbusJwtDecoder implements JwtDecoder {
if (this.signatureAlgorithms.isEmpty()) { if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
} }
else {
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>(); Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
@ -324,7 +318,6 @@ public final class NimbusJwtDecoder implements JwtDecoder {
} }
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
} }
}
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) { JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
if (this.cache == null) { if (this.cache == null) {
@ -339,13 +332,10 @@ public final class NimbusJwtDecoder implements JwtDecoder {
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever); JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever);
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
return jwtProcessor; return jwtProcessor;
} }
@ -397,10 +387,10 @@ public final class NimbusJwtDecoder implements JwtDecoder {
@Override @Override
public Resource retrieveResource(URL url) throws IOException { public Resource retrieveResource(URL url) throws IOException {
String jwkSet;
try { try {
jwkSet = this.cache.get(url.toString(), String jwkSet = this.cache.get(url.toString(),
() -> this.resourceRetriever.retrieveResource(url).getContent()); () -> this.resourceRetriever.retrieveResource(url).getContent());
return new Resource(jwkSet, "UTF-8");
} }
catch (Cache.ValueRetrievalException ex) { catch (Cache.ValueRetrievalException ex) {
Throwable thrownByValueLoader = ex.getCause(); Throwable thrownByValueLoader = ex.getCause();
@ -412,8 +402,6 @@ public final class NimbusJwtDecoder implements JwtDecoder {
catch (Exception ex) { catch (Exception ex) {
throw new IOException(ex); throw new IOException(ex);
} }
return new Resource(jwkSet, "UTF-8");
} }
} }
@ -433,21 +421,21 @@ public final class NimbusJwtDecoder implements JwtDecoder {
public Resource retrieveResource(URL url) throws IOException { public Resource retrieveResource(URL url) throws IOException {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
ResponseEntity<String> response = getResponse(url, headers);
if (response.getStatusCodeValue() != 200) {
throw new IOException(response.toString());
}
return new Resource(response.getBody(), "UTF-8");
}
ResponseEntity<String> response; private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException {
try { try {
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
response = this.restOperations.exchange(request, String.class); return this.restOperations.exchange(request, String.class);
} }
catch (Exception ex) { catch (Exception ex) {
throw new IOException(ex); throw new IOException(ex);
} }
if (response.getStatusCodeValue() != 200) {
throw new IOException(response.toString());
}
return new Resource(response.getBody(), "UTF-8");
} }
} }
@ -506,22 +494,16 @@ public final class NimbusJwtDecoder implements JwtDecoder {
} }
JWTProcessor<SecurityContext> processor() { JWTProcessor<SecurityContext> processor() {
if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) { Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
throw new IllegalStateException( () -> "The provided key is of type RSA; however the signature algorithm is of some other type: "
"The provided key is of type RSA; " + "however the signature algorithm is of some other type: "
+ this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512."); + this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
}
JWSKeySelector<SecurityContext> jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key); JWSKeySelector<SecurityContext> jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
return jwtProcessor; return jwtProcessor;
} }
@ -599,13 +581,10 @@ public final class NimbusJwtDecoder implements JwtDecoder {
this.secretKey); this.secretKey);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
return jwtProcessor; return jwtProcessor;
} }

View File

@ -79,7 +79,6 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) { public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) {
Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty"); Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty");
Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty"); Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
this.jwtDecoderBuilder = NimbusJwtDecoder.withJwkSetUri(jwkSetUrl) this.jwtDecoderBuilder = NimbusJwtDecoder.withJwkSetUri(jwkSetUrl)
.jwsAlgorithm(SignatureAlgorithm.from(jwsAlgorithm)); .jwsAlgorithm(SignatureAlgorithm.from(jwsAlgorithm));
this.delegate = makeDelegate(); this.delegate = makeDelegate();

View File

@ -161,8 +161,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
try { try {
return this.jwtProcessor.convert(parsedToken).map((set) -> createJwt(parsedToken, set)) return this.jwtProcessor.convert(parsedToken).map((set) -> createJwt(parsedToken, set))
.map(this::validateJwt) .map(this::validateJwt)
.onErrorMap((e) -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), .onErrorMap((ex) -> !(ex instanceof IllegalStateException) && !(ex instanceof JwtException),
(e) -> new JwtException("An error occurred while attempting to decode the Jwt: ", e)); (ex) -> new JwtException("An error occurred while attempting to decode the Jwt: ", ex));
} }
catch (JwtException ex) { catch (JwtException ex) {
throw ex; throw ex;
@ -176,7 +176,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
try { try {
Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
return Jwt.withTokenValue(parsedJwt.getParsedString()).headers((h) -> h.putAll(headers)) return Jwt.withTokenValue(parsedJwt.getParsedString()).headers((h) -> h.putAll(headers))
.claims((c) -> c.putAll(claims)).build(); .claims((c) -> c.putAll(claims)).build();
} }
@ -189,19 +188,21 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt); OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
if (result.hasErrors()) { if (result.hasErrors()) {
Collection<OAuth2Error> errors = result.getErrors(); Collection<OAuth2Error> errors = result.getErrors();
String validationErrorString = "Unable to validate Jwt"; String validationErrorString = getJwtValidationExceptionMessage(errors);
for (OAuth2Error oAuth2Error : errors) {
if (!StringUtils.isEmpty(oAuth2Error.getDescription())) {
validationErrorString = oAuth2Error.getDescription();
break;
}
}
throw new JwtValidationException(validationErrorString, errors); throw new JwtValidationException(validationErrorString, errors);
} }
return jwt; return jwt;
} }
private String getJwtValidationExceptionMessage(Collection<OAuth2Error> errors) {
for (OAuth2Error oAuth2Error : errors) {
if (!StringUtils.isEmpty(oAuth2Error.getDescription())) {
return oAuth2Error.getDescription();
}
}
return "Unable to validate Jwt";
}
/** /**
* Use the given <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> * Use the given <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a>
* uri to validate JWTs. * uri to validate JWTs.
@ -353,7 +354,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
if (this.signatureAlgorithms.isEmpty()) { if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource); return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
} }
else {
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>(); Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) { for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName()); JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
@ -361,7 +361,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
} }
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
} }
}
Converter<JWT, Mono<JWTClaimsSet>> processor() { Converter<JWT, Mono<JWTClaimsSet>> processor() {
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet(); JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
@ -370,16 +369,14 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri); ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
source.setWebClient(this.webClient); source.setWebClient(this.webClient);
Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector); Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
return (jwt) -> { return (jwt) -> {
JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader()); JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
return source.get(selector).onErrorMap((e) -> new IllegalStateException("Could not obtain the keys", e)) return source.get(selector)
.onErrorMap((ex) -> new IllegalStateException("Could not obtain the keys", ex))
.map((jwkList) -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList))); .map((jwkList) -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
}; };
} }
@ -396,7 +393,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
if (!expectedJwsAlgorithms.apply(jwsHeader.getAlgorithm())) { if (!expectedJwsAlgorithms.apply(jwsHeader.getAlgorithm())) {
throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm()); throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm());
} }
return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader)); return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
} }
@ -463,22 +459,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
} }
Converter<JWT, Mono<JWTClaimsSet>> processor() { Converter<JWT, Mono<JWTClaimsSet>> processor() {
if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) { Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
throw new IllegalStateException( () -> "The provided key is of type RSA; however the signature algorithm is of some other type: "
"The provided key is of type RSA; " + "however the signature algorithm is of some other type: "
+ this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512."); + this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
}
JWSKeySelector<SecurityContext> jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key); JWSKeySelector<SecurityContext> jwsKeySelector = new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
return (jwt) -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); return (jwt) -> Mono.just(createClaimsSet(jwtProcessor, jwt, null));
} }
@ -550,13 +540,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
this.secretKey); this.secretKey);
DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>(); DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
// Spring Security validates the claim set independent from Nimbus // Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
return (jwt) -> Mono.just(createClaimsSet(jwtProcessor, jwt, null)); return (jwt) -> Mono.just(createClaimsSet(jwtProcessor, jwt, null));
} }
@ -626,9 +613,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
jwtProcessor.setJWSKeySelector(jwsKeySelector); jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
}); });
this.jwtProcessorCustomizer.accept(jwtProcessor); this.jwtProcessorCustomizer.accept(jwtProcessor);
return (jwt) -> { return (jwt) -> {
if (jwt instanceof SignedJWT) { if (jwt instanceof SignedJWT) {
return this.jwkSource.apply((SignedJWT) jwt) return this.jwkSource.apply((SignedJWT) jwt)

View File

@ -33,6 +33,9 @@ import org.springframework.util.Assert;
*/ */
public final class ReactiveJwtDecoders { public final class ReactiveJwtDecoders {
private ReactiveJwtDecoders() {
}
/** /**
* Creates a {@link ReactiveJwtDecoder} using the provided <a href= * Creates a {@link ReactiveJwtDecoder} using the provided <a href=
* "https://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier">Issuer</a> * "https://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier">Issuer</a>
@ -106,11 +109,7 @@ public final class ReactiveJwtDecoders {
NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder
.withJwkSetUri(configuration.get("jwks_uri").toString()).build(); .withJwkSetUri(configuration.get("jwks_uri").toString()).build();
jwtDecoder.setJwtValidator(jwtValidator); jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder; return jwtDecoder;
} }
private ReactiveJwtDecoders() {
}
} }

View File

@ -63,29 +63,23 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
return Mono.defer(() -> { return Mono.defer(() -> {
// Run the selector on the JWK set // Run the selector on the JWK set
List<JWK> matches = jwkSelector.select(jwkSet); List<JWK> matches = jwkSelector.select(jwkSet);
if (!matches.isEmpty()) { if (!matches.isEmpty()) {
// Success // Success
return Mono.just(matches); return Mono.just(matches);
} }
// Refresh the JWK set if the sought key ID is not in the cached JWK set // Refresh the JWK set if the sought key ID is not in the cached JWK set
// Looking for JWK with specific ID? // Looking for JWK with specific ID?
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
if (soughtKeyID == null) { if (soughtKeyID == null) {
// No key ID specified, return no matches // No key ID specified, return no matches
return Mono.just(Collections.emptyList()); return Mono.just(Collections.emptyList());
} }
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
// The key ID exists in the cached JWK set, matching // The key ID exists in the cached JWK set, matching
// failed for some other reason, return no matches // failed for some other reason, return no matches
return Mono.just(Collections.emptyList()); return Mono.just(Collections.emptyList());
} }
return Mono.empty(); return Mono.empty();
}); });
} }
@ -114,13 +108,10 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
* @return The first key ID, {@code null} if none. * @return The first key ID, {@code null} if none.
*/ */
protected static String getFirstSpecifiedKeyID(final JWKMatcher jwkMatcher) { protected static String getFirstSpecifiedKeyID(final JWKMatcher jwkMatcher) {
Set<String> keyIDs = jwkMatcher.getKeyIDs(); Set<String> keyIDs = jwkMatcher.getKeyIDs();
if (keyIDs == null || keyIDs.isEmpty()) { if (keyIDs == null || keyIDs.isEmpty()) {
return null; return null;
} }
for (String id : keyIDs) { for (String id : keyIDs) {
if (id != null) { if (id != null) {
return id; return id;