Support overriding RestOperations in OidcIdTokenDecoderFactory

Closes gh-14178
This commit is contained in:
Armin Krezović 2023-12-20 21:24:13 +01:00 committed by Josh Cummings
parent 0041c658de
commit 9c352c4b4b
2 changed files with 44 additions and 3 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -49,6 +49,8 @@ import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
/**
* A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for
@ -89,6 +91,9 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
private Function<ClientRegistration, RestOperations> restOperationsFactory = (
clientRegistration) -> new RestTemplate();
/**
* Returns the default {@link Converter}'s used for type conversion of claim values
* for an {@link OidcIdToken}.
@ -164,7 +169,10 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
return NimbusJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
return NimbusJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
.restOperations(restOperationsFactory.apply(clientRegistration))
.build();
}
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
@ -237,4 +245,18 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
this.claimTypeConverterFactory = claimTypeConverterFactory;
}
/**
* Sets the factory that provides a {@link RestOperations} used by
* {@link NimbusJwtDecoder} to coordinate with the authorization servers indicated in
* the <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> uri.
* @param restOperationsFactory the factory that provides a {@link RestOperations}
* used by {@link NimbusJwtDecoder}
*
* @since 6.3
*/
public void setRestOperationsFactory(Function<ClientRegistration, RestOperations> restOperationsFactory) {
Assert.notNull(restOperationsFactory, "restOperationsFactory cannot be null");
this.restOperationsFactory = restOperationsFactory;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,6 +34,8 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -95,6 +97,12 @@ public class OidcIdTokenDecoderFactoryTests {
.isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null));
}
@Test
public void setRestOperationsFactoryWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.idTokenDecoderFactory.setRestOperationsFactory(null));
}
@Test
public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null));
@ -177,4 +185,15 @@ public class OidcIdTokenDecoderFactoryTests {
verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
}
@Test
public void createDecoderWhenCustomRestOperationsFactorySetThenApplied() {
Function<ClientRegistration, RestOperations> customRestOperationsFactory = mock(
Function.class);
this.idTokenDecoderFactory.setRestOperationsFactory(customRestOperationsFactory);
ClientRegistration clientRegistration = this.registration.build();
given(customRestOperationsFactory.apply(same(clientRegistration)))
.willReturn(new RestTemplate());
this.idTokenDecoderFactory.createDecoder(clientRegistration);
verify(customRestOperationsFactory).apply(same(clientRegistration));
}
}