Jwt Claim Mapping

This introduces a hook for users to customize standard Jwt Claim
values in cases where the JWT issuer isn't spec compliant or where the
user needs to add or remove claims.

Fixes: gh-5223
This commit is contained in:
Josh Cummings 2018-08-10 17:43:26 -06:00 committed by Rob Winch
parent 2495025845
commit 9e0f171d47
4 changed files with 510 additions and 14 deletions

View File

@ -0,0 +1,241 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.springframework.core.convert.converter.Converter;
import org.springframework.util.Assert;
/**
* Converts a JWT claim set, claim by claim. Can be configured with custom converters
* by claim name.
*
* @author Josh Cummings
* @since 5.1
*/
public final class MappedJwtClaimSetConverter
implements Converter<Map<String, Object>, Map<String, Object>> {
private static final Converter<Object, Collection<String>> AUDIENCE_CONVERTER = new AudienceConverter();
private static final Converter<Object, URL> ISSUER_CONVERTER = new IssuerConverter();
private static final Converter<Object, String> STRING_CONVERTER = new StringConverter();
private static final Converter<Object, Instant> TEMPORAL_CONVERTER = new InstantConverter();
private final Map<String, Converter<Object, ?>> claimConverters;
/**
* Constructs a {@link MappedJwtClaimSetConverter} with the provided arguments
*
* This will completely replace any set of default converters.
*
* @param claimConverters The {@link Map} of converters to use
*/
public MappedJwtClaimSetConverter(Map<String, Converter<Object, ?>> claimConverters) {
Assert.notNull(claimConverters, "claimConverters cannot be null");
this.claimConverters = new HashMap<>(claimConverters);
}
/**
* Construct a {@link MappedJwtClaimSetConverter}, overriding individual claim
* converters with the provided {@link Map} of {@link Converter}s.
*
* For example, the following would give an instance that is configured with only the default
* claim converters:
*
* <pre>
* MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
* </pre>
*
* Or, the following would supply a custom converter for the subject, leaving the other defaults
* in place:
*
* <pre>
* MappedJwtClaimsSetConverter.withDefaults(
* Collections.singletonMap(JwtClaimNames.SUB, new UserDetailsServiceJwtSubjectConverter()));
* </pre>
*
* To completely replace the underlying {@link Map} of converters, {@see MappedJwtClaimSetConverter(Map)}.
*
* @param claimConverters
* @return An instance of {@link MappedJwtClaimSetConverter} that contains the converters provided,
* plus any defaults that were not overridden.
*/
public static MappedJwtClaimSetConverter withDefaults
(Map<String, Converter<Object, ?>> claimConverters) {
Assert.notNull(claimConverters, "claimConverters cannot be null");
Map<String, Converter<Object, ?>> claimNameToConverter = new HashMap<>();
claimNameToConverter.put(JwtClaimNames.AUD, AUDIENCE_CONVERTER);
claimNameToConverter.put(JwtClaimNames.EXP, TEMPORAL_CONVERTER);
claimNameToConverter.put(JwtClaimNames.IAT, TEMPORAL_CONVERTER);
claimNameToConverter.put(JwtClaimNames.ISS, ISSUER_CONVERTER);
claimNameToConverter.put(JwtClaimNames.JTI, STRING_CONVERTER);
claimNameToConverter.put(JwtClaimNames.NBF, TEMPORAL_CONVERTER);
claimNameToConverter.put(JwtClaimNames.SUB, STRING_CONVERTER);
claimNameToConverter.putAll(claimConverters);
return new MappedJwtClaimSetConverter(claimNameToConverter);
}
/**
* {@inheritDoc}
*/
@Override
public Map<String, Object> convert(Map<String, Object> claims) {
Assert.notNull(claims, "claims cannot be null");
Map<String, Object> mappedClaims = new HashMap<>(claims);
for (Map.Entry<String, Converter<Object, ?>> entry : this.claimConverters.entrySet()) {
String claimName = entry.getKey();
Converter<Object, ?> converter = entry.getValue();
if (converter != null) {
Object claim = claims.get(claimName);
Object mappedClaim = converter.convert(claim);
mappedClaims.compute(claimName, (key, value) -> mappedClaim);
}
}
Instant issuedAt = (Instant) mappedClaims.get(JwtClaimNames.IAT);
Instant expiresAt = (Instant) mappedClaims.get(JwtClaimNames.EXP);
if (issuedAt == null && expiresAt != null) {
mappedClaims.put(JwtClaimNames.IAT, expiresAt.minusSeconds(1));
}
return mappedClaims;
}
/**
* Coerces an <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-4.1.3">Audience</a> claim
* into a {@link Collection<String>}, ignoring null values, and throwing an error if its coercion efforts fail.
*/
private static class AudienceConverter implements Converter<Object, Collection<String>> {
@Override
public Collection<String> convert(Object source) {
if (source == null) {
return null;
}
if (source instanceof Collection) {
return ((Collection<?>) source).stream()
.filter(Objects::nonNull)
.map(Objects::toString)
.collect(Collectors.toList());
}
return Arrays.asList(source.toString());
}
}
/**
* Coerces an <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-4.1.1">Issuer</a> claim
* into a {@link URL}, ignoring null values, and throwing an error if its coercion efforts fail.
*/
private static class IssuerConverter implements Converter<Object, URL> {
@Override
public URL convert(Object source) {
if (source == null) {
return null;
}
if (source instanceof URL) {
return (URL) source;
}
if (source instanceof URI) {
return toUrl((URI) source);
}
return toUrl(source.toString());
}
private URL toUrl(URI source) {
try {
return source.toURL();
} catch (MalformedURLException e) {
throw new IllegalStateException("Could not coerce " + source + " into a URL", e);
}
}
private URL toUrl(String source) {
try {
return new URL(source);
} catch (MalformedURLException e) {
throw new IllegalStateException("Could not coerce " + source + " into a URL", e);
}
}
}
/**
* Coerces a claim into an {@link Instant}, ignoring null values, and throwing an error
* if its coercion efforts fail.
*/
private static class InstantConverter implements Converter<Object, Instant> {
@Override
public Instant convert(Object source) {
if (source == null) {
return null;
}
if (source instanceof Instant) {
return (Instant) source;
}
if (source instanceof Date) {
return ((Date) source).toInstant();
}
if (source instanceof Number) {
return Instant.ofEpochSecond(((Number) source).longValue());
}
try {
return Instant.ofEpochSecond(Long.parseLong(source.toString()));
} catch (Exception e) {
throw new IllegalStateException("Could not coerce " + source + " into an Instant", e);
}
}
}
/**
* Coerces a claim into a {@link String}, ignoring null values, and throwing an error if its
* coercion efforts fail.
*/
private static class StringConverter implements Converter<Object, String> {
@Override
public String convert(Object source) {
if (source == null) {
return null;
}
return source.toString();
}
}
}

View File

@ -40,6 +40,7 @@ import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -78,8 +79,11 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
private final ConfigurableJWTProcessor<SecurityContext> jwtProcessor;
private final RestOperationsResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever();
private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter =
MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
/**
* Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters.
*
@ -134,6 +138,16 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
this.jwtValidator = jwtValidator;
}
/**
* Use the following {@link Converter} for manipulating the JWT's claim set
*
* @param claimSetConverter the {@link Converter} to use
*/
public final void setClaimSetConverter(Converter<Map<String, Object>, Map<String, Object>> claimSetConverter) {
Assert.notNull(claimSetConverter, "claimSetConverter cannot be null");
this.claimSetConverter = claimSetConverter;
}
private JWT parse(String token) {
try {
return JWTParser.parse(token);
@ -149,22 +163,12 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
// Verify the signature
JWTClaimsSet jwtClaimsSet = this.jwtProcessor.process(parsedJwt, null);
Instant expiresAt = null;
if (jwtClaimsSet.getExpirationTime() != null) {
expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
}
Instant issuedAt = null;
if (jwtClaimsSet.getIssueTime() != null) {
issuedAt = jwtClaimsSet.getIssueTime().toInstant();
} else if (expiresAt != null) {
// Default to expiresAt - 1 second
issuedAt = Instant.from(expiresAt).minusSeconds(1);
}
Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
jwt = new Jwt(token, issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP);
Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT);
jwt = new Jwt(token, issuedAt, expiresAt, headers, claims);
} catch (RemoteKeySourceException ex) {
if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"));

View File

@ -0,0 +1,223 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.springframework.core.convert.converter.Converter;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link MappedJwtClaimSetConverter}
*
* @author Josh Cummings
*/
public class MappedJwtClaimSetConverterTests {
@Test
public void convertWhenUsingCustomExpiresAtConverterThenIssuedAtConverterStillConsultsIt() {
Instant at = Instant.ofEpochMilli(1000000000000L);
Converter<Object, Instant> expiresAtConverter = mock(Converter.class);
when(expiresAtConverter.convert(any())).thenReturn(at);
MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
.withDefaults(Collections.singletonMap(JwtClaimNames.EXP, expiresAtConverter));
Map<String, Object> source = new HashMap<>();
Map<String, Object> target = converter.convert(source);
assertThat(target.get(JwtClaimNames.IAT)).
isEqualTo(Instant.ofEpochMilli(at.toEpochMilli()).minusSeconds(1));
}
@Test
public void convertWhenUsingDefaultsThenBasesIssuedAtOffOfExpiration() {
MappedJwtClaimSetConverter converter =
MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
Map<String, Object> source = Collections.singletonMap(JwtClaimNames.EXP, 1000000000L);
Map<String, Object> target = converter.convert(source);
assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(1000000000L));
assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L).minusSeconds(1));
}
@Test
public void convertWhenUsingDefaultsThenCoercesAudienceAccordingToJwtSpec() {
MappedJwtClaimSetConverter converter =
MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
Map<String, Object> source = Collections.singletonMap(JwtClaimNames.AUD, "audience");
Map<String, Object> target = converter.convert(source);
assertThat(target.get(JwtClaimNames.AUD)).isInstanceOf(Collection.class);
assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience"));
source = Collections.singletonMap(JwtClaimNames.AUD, Arrays.asList("one", "two"));
target = converter.convert(source);
assertThat(target.get(JwtClaimNames.AUD)).isInstanceOf(Collection.class);
assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("one", "two"));
}
@Test
public void convertWhenUsingDefaultsThenCoercesAllAttributesInJwtSpec() throws Exception {
MappedJwtClaimSetConverter converter =
MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
source.put(JwtClaimNames.JTI, 1);
source.put(JwtClaimNames.AUD, "audience");
source.put(JwtClaimNames.EXP, 2000000000L);
source.put(JwtClaimNames.IAT, new Date(1000000000000L));
source.put(JwtClaimNames.ISS, "https://any.url");
source.put(JwtClaimNames.NBF, 1000000000);
source.put(JwtClaimNames.SUB, 1234);
Map<String, Object> target = converter.convert(source);
assertThat(target.get(JwtClaimNames.JTI)).isEqualTo("1");
assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience"));
assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(2000000000L));
assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L));
assertThat(target.get(JwtClaimNames.ISS)).isEqualTo(new URL("https://any.url"));
assertThat(target.get(JwtClaimNames.NBF)).isEqualTo(Instant.ofEpochSecond(1000000000L));
assertThat(target.get(JwtClaimNames.SUB)).isEqualTo("1234");
}
@Test
public void convertWhenUsingCustomConverterThenAllOtherDefaultsAreStillUsed() throws Exception {
Converter<Object, String> claimConverter = mock(Converter.class);
MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
.withDefaults(Collections.singletonMap(JwtClaimNames.SUB, claimConverter));
when(claimConverter.convert(any(Object.class))).thenReturn("1234");
Map<String, Object> source = new HashMap<>();
source.put(JwtClaimNames.JTI, 1);
source.put(JwtClaimNames.AUD, "audience");
source.put(JwtClaimNames.EXP, Instant.ofEpochSecond(2000000000L));
source.put(JwtClaimNames.IAT, new Date(1000000000000L));
source.put(JwtClaimNames.ISS, URI.create("https://any.url"));
source.put(JwtClaimNames.NBF, "1000000000");
source.put(JwtClaimNames.SUB, 2345);
Map<String, Object> target = converter.convert(source);
assertThat(target.get(JwtClaimNames.JTI)).isEqualTo("1");
assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience"));
assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(2000000000L));
assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L));
assertThat(target.get(JwtClaimNames.ISS)).isEqualTo(new URL("https://any.url"));
assertThat(target.get(JwtClaimNames.NBF)).isEqualTo(Instant.ofEpochSecond(1000000000L));
assertThat(target.get(JwtClaimNames.SUB)).isEqualTo("1234");
}
@Test
public void convertWhenConverterReturnsNullThenClaimIsRemoved() {
MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
.withDefaults(Collections.emptyMap());
Map<String, Object> source = Collections.singletonMap(JwtClaimNames.ISS, null);
Map<String, Object> target = converter.convert(source);
assertThat(target).doesNotContainKey(JwtClaimNames.ISS);
}
@Test
public void convertWhenConverterReturnsValueWhenEntryIsMissingThenEntryIsAdded() {
Converter<Object, String> claimConverter = mock(Converter.class);
MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
.withDefaults(Collections.singletonMap("custom-claim", claimConverter));
when(claimConverter.convert(any())).thenReturn("custom-value");
Map<String, Object> source = new HashMap<>();
Map<String, Object> target = converter.convert(source);
assertThat(target.get("custom-claim")).isEqualTo("custom-value");
}
@Test
public void convertWhenUsingConstructorThenOnlyConvertersInThatMapAreUsedForConversion() {
Converter<Object, String> claimConverter = mock(Converter.class);
MappedJwtClaimSetConverter converter = new MappedJwtClaimSetConverter(
Collections.singletonMap(JwtClaimNames.SUB, claimConverter));
when(claimConverter.convert(any(Object.class))).thenReturn("1234");
Map<String, Object> source = new HashMap<>();
source.put(JwtClaimNames.JTI, new Object());
source.put(JwtClaimNames.AUD, new Object());
source.put(JwtClaimNames.EXP, Instant.ofEpochSecond(1L));
source.put(JwtClaimNames.IAT, Instant.ofEpochSecond(1L));
source.put(JwtClaimNames.ISS, new Object());
source.put(JwtClaimNames.NBF, new Object());
source.put(JwtClaimNames.SUB, new Object());
Map<String, Object> target = converter.convert(source);
assertThat(target.get(JwtClaimNames.JTI)).isEqualTo(source.get(JwtClaimNames.JTI));
assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(source.get(JwtClaimNames.AUD));
assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(source.get(JwtClaimNames.EXP));
assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(source.get(JwtClaimNames.IAT));
assertThat(target.get(JwtClaimNames.ISS)).isEqualTo(source.get(JwtClaimNames.ISS));
assertThat(target.get(JwtClaimNames.NBF)).isEqualTo(source.get(JwtClaimNames.NBF));
assertThat(target.get(JwtClaimNames.SUB)).isEqualTo("1234");
}
@Test
public void convertWhenUsingDefaultsThenFailedConversionThrowsIllegalStateException() {
MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
.withDefaults(Collections.emptyMap());
Map<String, Object> badIssuer = Collections.singletonMap(JwtClaimNames.ISS, "badly-formed-iss");
assertThatCode(() -> converter.convert(badIssuer)).isInstanceOf(IllegalStateException.class);
Map<String, Object> badIssuedAt = Collections.singletonMap(JwtClaimNames.IAT, "badly-formed-iat");
assertThatCode(() -> converter.convert(badIssuedAt)).isInstanceOf(IllegalStateException.class);
Map<String, Object> badExpiresAt = Collections.singletonMap(JwtClaimNames.EXP, "badly-formed-exp");
assertThatCode(() -> converter.convert(badExpiresAt)).isInstanceOf(IllegalStateException.class);
Map<String, Object> badNotBefore = Collections.singletonMap(JwtClaimNames.NBF, "badly-formed-nbf");
assertThatCode(() -> converter.convert(badNotBefore)).isInstanceOf(IllegalStateException.class);
}
@Test
public void constructWhenAnyParameterIsNullThenIllegalArgumentException() {
assertThatCode(() -> new MappedJwtClaimSetConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void withDefaultsWhenAnyParameterIsNullThenIllegalArgumentException() {
assertThatCode(() -> MappedJwtClaimSetConverter.withDefaults(null))
.isInstanceOf(IllegalArgumentException.class);
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.oauth2.jwt;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
@ -33,6 +35,7 @@ import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
@ -40,6 +43,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.web.client.RestTemplate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
@ -228,4 +232,28 @@ public class NimbusJwtDecoderJwkSupportTests {
.hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure));
}
}
@Test
public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() throws Exception {
try ( MockWebServer server = new MockWebServer() ) {
server.enqueue(new MockResponse().setBody(JWK_SET));
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value"));
decoder.setClaimSetConverter(claimSetConverter);
Jwt jwt = decoder.decode(SIGNED_JWT);
assertThat(jwt.getClaims().size()).isEqualTo(1);
assertThat(jwt.getClaims().get("custom")).isEqualTo("value");
}
}
@Test
public void setClaimSetConverterWhenIsNullThenThrowsIllegalArgumentException() {
assertThatCode(() -> jwtDecoder.setClaimSetConverter(null))
.isInstanceOf(IllegalArgumentException.class);
}
}