Expose RestOperations in NimbusJwtDecoderJwkSupport

Fixes gh-5603
This commit is contained in:
Joe Grandja 2018-08-08 13:20:45 -04:00
parent 11984039c2
commit 16fe1c5b52
2 changed files with 102 additions and 49 deletions

View File

@ -15,13 +15,6 @@
*/ */
package org.springframework.security.oauth2.jwt; package org.springframework.security.oauth2.jwt;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.time.Instant;
import java.util.LinkedHashMap;
import java.util.Map;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.JWKSource;
@ -29,7 +22,7 @@ import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.DefaultResourceRetriever; import com.nimbusds.jose.util.Resource;
import com.nimbusds.jose.util.ResourceRetriever; import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
@ -37,12 +30,27 @@ import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
/** /**
* An implementation of a {@link JwtDecoder} that "decodes" a * An implementation of a {@link JwtDecoder} that "decodes" a
* JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a * JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a
* JSON Web Signature (JWS). The public key used for verification is obtained from the * JSON Web Signature (JWS). The public key used for verification is obtained from the
* JSON Web Key (JWK) Set {@code URL} supplied via the constructor. * JSON Web Key (JWK) Set {@code URL} supplied via the constructor.
@ -63,9 +71,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
private static final String DECODING_ERROR_MESSAGE_TEMPLATE = private static final String DECODING_ERROR_MESSAGE_TEMPLATE =
"An error occurred while attempting to decode the Jwt: %s"; "An error occurred while attempting to decode the Jwt: %s";
private final URL jwkSetUrl;
private final JWSAlgorithm jwsAlgorithm; private final JWSAlgorithm jwsAlgorithm;
private final ConfigurableJWTProcessor<SecurityContext> jwtProcessor; private final ConfigurableJWTProcessor<SecurityContext> jwtProcessor;
private final RestOperationsResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever();
/** /**
* Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters. * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters.
@ -85,18 +93,15 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) { public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) {
Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty"); Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty");
Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty"); Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
JWKSource jwkSource;
try { try {
this.jwkSetUrl = new URL(jwkSetUrl); jwkSource = new RemoteJWKSet(new URL(jwkSetUrl), this.jwkSetRetriever);
} catch (MalformedURLException ex) { } catch (MalformedURLException ex) {
throw new IllegalArgumentException("Invalid JWK Set URL " + jwkSetUrl + " : " + ex.getMessage(), ex); throw new IllegalArgumentException("Invalid JWK Set URL \"" + jwkSetUrl + "\" : " + ex.getMessage(), ex);
} }
this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm); this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
ResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(30000, 30000);
JWKSource jwkSource = new RemoteJWKSet(this.jwkSetUrl, jwkSetRetriever);
JWSKeySelector<SecurityContext> jwsKeySelector = JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
this.jwtProcessor = new DefaultJWTProcessor<>(); this.jwtProcessor = new DefaultJWTProcessor<>();
this.jwtProcessor.setJWSKeySelector(jwsKeySelector); this.jwtProcessor.setJWSKeySelector(jwsKeySelector);
} }
@ -104,10 +109,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
@Override @Override
public Jwt decode(String token) throws JwtException { public Jwt decode(String token) throws JwtException {
JWT jwt = this.parse(token); JWT jwt = this.parse(token);
if ( jwt instanceof SignedJWT ) { if (jwt instanceof SignedJWT) {
return this.createJwt(token, jwt); return this.createJwt(token, jwt);
} }
throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
} }
@ -158,4 +162,39 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
return jwt; return jwt;
} }
/**
* Sets the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set.
*
* @since 5.1
* @param restOperations the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set
*/
public final void setRestOperations(RestOperations restOperations) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.jwkSetRetriever.restOperations = restOperations;
}
private static class RestOperationsResourceRetriever implements ResourceRetriever {
private RestOperations restOperations = new RestTemplate();
@Override
public Resource retrieveResource(URL url) throws IOException {
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
ResponseEntity<String> response;
try {
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
response = this.restOperations.exchange(request, String.class);
} catch (Exception ex) {
throw new IOException(ex);
}
if (response.getStatusCodeValue() != 200) {
throw new IOException(response.toString());
}
return new Resource(response.getBody(), "UTF-8");
}
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,23 +24,22 @@ import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.Assertions;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.web.client.RestTemplate;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.*;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.powermock.api.mockito.PowerMockito.mockStatic; import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.when; import static org.powermock.api.mockito.PowerMockito.*;
import static org.powermock.api.mockito.PowerMockito.whenNew;
/** /**
* Tests for {@link NimbusJwtDecoderJwkSupport}. * Tests for {@link NimbusJwtDecoderJwkSupport}.
@ -62,6 +61,8 @@ public class NimbusJwtDecoderJwkSupportTests {
private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A"; private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A";
private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9."; private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9.";
private NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
@Test @Test
public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() { public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null)) assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null))
@ -80,10 +81,15 @@ public class NimbusJwtDecoderJwkSupportTests {
.isInstanceOf(IllegalArgumentException.class); .isInstanceOf(IllegalArgumentException.class);
} }
@Test
public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
Assertions.assertThatThrownBy(() -> this.jwtDecoder.setRestOperations(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test @Test
public void decodeWhenJwtInvalidThenThrowJwtException() { public void decodeWhenJwtInvalidThenThrowJwtException() {
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); assertThatThrownBy(() -> this.jwtDecoder.decode("invalid"))
assertThatThrownBy(() -> jwtDecoder.decode("invalid"))
.isInstanceOf(JwtException.class); .isInstanceOf(JwtException.class);
} }
@ -103,16 +109,14 @@ public class NimbusJwtDecoderJwkSupportTests {
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().audience("resource1").build(); JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().audience("resource1").build();
when(jwtProcessor.process(any(JWT.class), eq(null))).thenReturn(jwtClaimsSet); when(jwtProcessor.process(any(JWT.class), eq(null))).thenReturn(jwtClaimsSet);
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL);
assertThatCode(() -> jwtDecoder.decode("encoded-jwt")).doesNotThrowAnyException(); assertThatCode(() -> jwtDecoder.decode("encoded-jwt")).doesNotThrowAnyException();
} }
// gh-5457 // gh-5457
@Test @Test
public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() throws Exception { public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() {
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT))
assertThatCode(() -> jwtDecoder.decode(UNSIGNED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(JwtException.class)
.hasMessageContaining("Unsupported algorithm of none"); .hasMessageContaining("Unsupported algorithm of none");
} }
@ -122,12 +126,11 @@ public class NimbusJwtDecoderJwkSupportTests {
try ( MockWebServer server = new MockWebServer() ) { try ( MockWebServer server = new MockWebServer() ) {
server.enqueue(new MockResponse().setBody(JWK_SET)); server.enqueue(new MockResponse().setBody(JWK_SET));
String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); assertThatCode(() -> jwtDecoder.decode(MALFORMED_JWT))
assertThatCode(() -> decoder.decode(MALFORMED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(JwtException.class)
.hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); .hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload");
server.shutdown();
} }
} }
@ -136,28 +139,39 @@ public class NimbusJwtDecoderJwkSupportTests {
try ( MockWebServer server = new MockWebServer() ) { try ( MockWebServer server = new MockWebServer() ) {
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET)); server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
assertThatCode(() -> decoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(JwtException.class)
.hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); .hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set");
server.shutdown();
} }
} }
@Test @Test
public void decodeWhenJwkEndpointIsUnresponsiveThenRetrunsJwtException() throws Exception { public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception {
try ( MockWebServer server = new MockWebServer() ) { try ( MockWebServer server = new MockWebServer() ) {
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET)); server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
server.shutdown();
assertThatCode(() -> decoder.decode(SIGNED_JWT))
.isInstanceOf(JwtException.class) .isInstanceOf(JwtException.class)
.hasMessageContaining("An error occurred while attempting to decode the Jwt"); .hasMessageContaining("An error occurred while attempting to decode the Jwt");
server.shutdown();
}
}
// gh-5603
@Test
public void decodeWhenCustomRestOperationsSetThenUsed() throws Exception {
try ( MockWebServer server = new MockWebServer() ) {
server.enqueue(new MockResponse().setBody(JWK_SET));
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
RestTemplate restTemplate = spy(new RestTemplate());
jwtDecoder.setRestOperations(restTemplate);
assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)).doesNotThrowAnyException();
verify(restTemplate).exchange(any(RequestEntity.class), eq(String.class));
server.shutdown();
} }
} }
} }