Reactive Jwt Claim Set Converter Support

Exposes setClaimSetConverter on NimbusReactiveJwtDecoder, lining it up
with the same support on NimbusJwtDecoder.

Fixes: gh-6015
This commit is contained in:
Josh Cummings 2018-11-12 14:47:56 -07:00 committed by Rob Winch
parent 11b6b63364
commit ae74f22e30
2 changed files with 43 additions and 15 deletions

View File

@ -17,6 +17,7 @@ package org.springframework.security.oauth2.jwt;
import java.security.interfaces.RSAPublicKey;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -40,6 +41,7 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import reactor.core.publisher.Mono;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
@ -70,6 +72,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
private final JWKSelectorFactory jwkSelectorFactory;
private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = MappedJwtClaimSetConverter
.withDefaults(Collections.emptyMap());
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
@ -122,6 +126,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
this.jwtValidator = jwtValidator;
}
/**
* Use the following {@link Converter} for manipulating the JWT's claim set
*
* @param claimSetConverter the {@link Converter} to use
*/
public void setClaimSetConverter(Converter<Map<String, Object>, Map<String, Object>> claimSetConverter) {
Assert.notNull(claimSetConverter, "claimSetConverter cannot be null");
this.claimSetConverter = claimSetConverter;
}
@Override
public Mono<Jwt> decode(String token) throws JwtException {
JWT jwt = parse(token);
@ -164,21 +178,12 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
}
private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
Instant expiresAt = null;
if (jwtClaimsSet.getExpirationTime() != null) {
expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
}
Instant issuedAt = null;
if (jwtClaimsSet.getIssueTime() != null) {
issuedAt = jwtClaimsSet.getIssueTime().toInstant();
} else if (expiresAt != null) {
// Default to expiresAt - 1 second
issuedAt = Instant.from(expiresAt).minusSeconds(1);
}
Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP);
Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT);
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, claims);
}
private Jwt validateJwt(Jwt jwt) {

View File

@ -20,8 +20,10 @@ import java.net.UnknownHostException;
import java.security.KeyFactory;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.X509EncodedKeySpec;
import java.time.Instant;
import java.util.Base64;
import java.util.Date;
import java.util.Collections;
import java.util.Map;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
@ -29,6 +31,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
@ -37,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
@ -115,7 +119,7 @@ public class NimbusReactiveJwtDecoderTests {
Jwt jwt = this.decoder.decode(withIssuedAt).block();
assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(new Date(1529942448000L));
assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1529942448L));
}
@Test
@ -177,9 +181,28 @@ public class NimbusReactiveJwtDecoderTests {
.hasMessageContaining("mock-description");
}
@Test
public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() {
Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
this.decoder.setClaimSetConverter(claimSetConverter);
when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value"));
Jwt jwt = this.decoder.decode(this.messageReadToken).block();
assertThat(jwt.getClaims().size()).isEqualTo(1);
assertThat(jwt.getClaims().get("custom")).isEqualTo("value");
verify(claimSetConverter).convert(any(Map.class));
}
@Test
public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
assertThatCode(() -> this.decoder.setJwtValidator(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setClaimSetConverterWhenNullThrowsIllegalArgumentException() {
assertThatCode(() -> this.decoder.setClaimSetConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
}