diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index d00490bf23..c66994df16 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,8 @@ import org.springframework.security.oauth2.jwt.JwtTimestampValidator; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; /** * A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for @@ -89,6 +91,9 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + private Function restOperationsFactory = ( + clientRegistration) -> new RestTemplate(); + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcIdToken}. @@ -164,7 +169,10 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactoryJWK Set uri. + * @param restOperationsFactory the factory that provides a {@link RestOperations} + * used by {@link NimbusJwtDecoder} + * + * @since 6.3 + */ + public void setRestOperationsFactory(Function restOperationsFactory) { + Assert.notNull(restOperationsFactory, "restOperationsFactory cannot be null"); + this.restOperationsFactory = restOperationsFactory; + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java index 33663bac65..0f6015e410 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,8 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -95,6 +97,12 @@ public class OidcIdTokenDecoderFactoryTests { .isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)); } + @Test + public void setRestOperationsFactoryWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.idTokenDecoderFactory.setRestOperationsFactory(null)); + } + @Test public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)); @@ -177,4 +185,15 @@ public class OidcIdTokenDecoderFactoryTests { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderWhenCustomRestOperationsFactorySetThenApplied() { + Function customRestOperationsFactory = mock( + Function.class); + this.idTokenDecoderFactory.setRestOperationsFactory(customRestOperationsFactory); + ClientRegistration clientRegistration = this.registration.build(); + given(customRestOperationsFactory.apply(same(clientRegistration))) + .willReturn(new RestTemplate()); + this.idTokenDecoderFactory.createDecoder(clientRegistration); + verify(customRestOperationsFactory).apply(same(clientRegistration)); + } }