diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index 24b24909c9..3af7db6d02 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -98,6 +98,10 @@ public final class OAuth2ClientConfigurer> private AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer = new AuthorizationCodeGrantConfigurer(); + private ClientRegistrationRepository clientRegistrationRepository; + + private OAuth2AuthorizedClientRepository authorizedClientRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -107,6 +111,7 @@ public final class OAuth2ClientConfigurer> ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); + this.clientRegistrationRepository = clientRegistrationRepository; return this; } @@ -119,6 +124,7 @@ public final class OAuth2ClientConfigurer> OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository); + this.authorizedClientRepository = authorizedClientRepository; return this; } @@ -283,8 +289,7 @@ public final class OAuth2ClientConfigurer> if (this.authorizationRequestResolver != null) { return this.authorizationRequestResolver; } - ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils - .getClientRegistrationRepository(getBuilder()); + ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(getBuilder()); return new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); } @@ -292,8 +297,8 @@ public final class OAuth2ClientConfigurer> private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B builder) { AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class); OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), - OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder), authenticationManager); + getClientRegistrationRepository(builder), getAuthorizedClientRepository(builder), + authenticationManager); if (this.authorizationRequestRepository != null) { authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } @@ -315,6 +320,18 @@ public final class OAuth2ClientConfigurer> return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient(); } + private ClientRegistrationRepository getClientRegistrationRepository(B builder) { + return (OAuth2ClientConfigurer.this.clientRegistrationRepository != null) + ? OAuth2ClientConfigurer.this.clientRegistrationRepository + : OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder); + } + + private OAuth2AuthorizedClientRepository getAuthorizedClientRepository(B builder) { + return (OAuth2ClientConfigurer.this.authorizedClientRepository != null) + ? OAuth2ClientConfigurer.this.authorizedClientRepository + : OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder); + } + @SuppressWarnings("unchecked") private T getBeanOrNull(ResolvableType type) { ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index a6b5f7c52b..913a8f1211 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -172,6 +172,10 @@ public final class OAuth2LoginConfigurer> private String loginProcessingUrl = OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; + private ClientRegistrationRepository clientRegistrationRepository; + + private OAuth2AuthorizedClientRepository authorizedClientRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -181,6 +185,7 @@ public final class OAuth2LoginConfigurer> ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); + this.clientRegistrationRepository = clientRegistrationRepository; return this; } @@ -194,6 +199,7 @@ public final class OAuth2LoginConfigurer> OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository); + this.authorizedClientRepository = authorizedClientRepository; return this; } @@ -339,8 +345,7 @@ public final class OAuth2LoginConfigurer> @Override public void init(B http) throws Exception { OAuth2LoginAuthenticationFilter authenticationFilter = new OAuth2LoginAuthenticationFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), - OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()), this.loginProcessingUrl); + this.getClientRegistrationRepository(), this.getAuthorizedClientRepository(), this.loginProcessingUrl); authenticationFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); this.setAuthenticationFilter(authenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); @@ -406,8 +411,7 @@ public final class OAuth2LoginConfigurer> authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; } authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), - authorizationRequestBaseUri); + this.getClientRegistrationRepository(), authorizationRequestBaseUri); } if (this.authorizationEndpointConfig.authorizationRequestRepository != null) { authorizationRequestFilter @@ -439,6 +443,16 @@ public final class OAuth2LoginConfigurer> return new AntPathRequestMatcher(loginProcessingUrl); } + private ClientRegistrationRepository getClientRegistrationRepository() { + return (this.clientRegistrationRepository != null) ? this.clientRegistrationRepository + : OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()); + } + + private OAuth2AuthorizedClientRepository getAuthorizedClientRepository() { + return (this.authorizedClientRepository != null) ? this.authorizedClientRepository + : OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()); + } + @SuppressWarnings("unchecked") private JwtDecoderFactory getJwtDecoderFactoryBean() { ResolvableType type = ResolvableType.forClassWithGenerics(JwtDecoderFactory.class, ClientRegistration.class); @@ -529,8 +543,7 @@ public final class OAuth2LoginConfigurer> @SuppressWarnings("unchecked") private Map getLoginLinks() { Iterable clientRegistrations = null; - ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils - .getClientRegistrationRepository(this.getBuilder()); + ClientRegistrationRepository clientRegistrationRepository = this.getClientRegistrationRepository(); ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class); if (type != ResolvableType.NONE && ClientRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) { clientRegistrations = (Iterable) clientRegistrationRepository; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index 41e74807cd..83dacaa265 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,6 +75,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; @@ -285,6 +286,49 @@ public class OAuth2ClientConfigurerTests { verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); } + @Test + public void configureWhenOAuth2LoginBeansConfiguredThenNotShared() throws Exception { + this.spring.register(OAuth2ClientConfigWithOAuth2Login.class).autowire(); + // Setup the Authorization Request in the session + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) + .clientId(this.registration1.getClientId()) + .redirectUri("http://localhost/client-1") + .state("state") + .attributes(attributes) + .build(); + // @formatter:on + AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); + MockHttpSession session = (MockHttpSession) request.getSession(); + String principalName = "user1"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); + // @formatter:off + MockHttpServletRequestBuilder clientRequest = get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(authentication(authentication)) + .session(session); + this.mockMvc.perform(clientRequest) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl("http://localhost/client-1")); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = authorizedClientRepository + .loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request); + assertThat(authorizedClient).isNotNull(); + // Ensure shared objects set for OAuth2 Client are not used + ClientRegistrationRepository clientRegistrationRepository = this.spring.getContext() + .getBean(ClientRegistrationRepository.class); + OAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(OAuth2AuthorizedClientRepository.class); + verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); + } + @EnableWebSecurity @Configuration @EnableWebMvc @@ -362,4 +406,51 @@ public class OAuth2ClientConfigurerTests { } + @Configuration + @EnableWebSecurity + @EnableWebMvc + static class OAuth2ClientConfigWithOAuth2Login { + + private final ClientRegistrationRepository clientRegistrationRepository = mock( + ClientRegistrationRepository.class); + + private final OAuth2AuthorizedClientRepository authorizedClientRepository = mock( + OAuth2AuthorizedClientRepository.class); + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authorize) -> authorize + .anyRequest().authenticated() + ) + .oauth2Client((oauth2Client) -> oauth2Client + .clientRegistrationRepository(OAuth2ClientConfigurerTests.clientRegistrationRepository) + .authorizedClientService(OAuth2ClientConfigurerTests.authorizedClientService) + .authorizationCodeGrant((authorizationCode) -> authorizationCode + .authorizationRequestResolver(authorizationRequestResolver) + .authorizationRedirectStrategy(authorizationRedirectStrategy) + .accessTokenResponseClient(accessTokenResponseClient) + ) + ) + .oauth2Login((oauth2Login) -> oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizedClientRepository(this.authorizedClientRepository) + ); + // @formatter:on + return http.build(); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return this.clientRegistrationRepository; + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return this.authorizedClientRepository; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index b56d047a5f..dfe6fea28f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,7 +73,9 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -115,6 +117,7 @@ import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -669,6 +672,30 @@ public class OAuth2LoginConfigurerTests { .collect(Collectors.toList())).isEmpty(); } + @Test + public void oidcLoginWhenOAuth2ClientBeansConfiguredThenNotShared() throws Exception { + this.spring.register(OAuth2LoginConfigWithOAuth2Client.class, JwtDecoderFactoryConfig.class).autowire(); + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(1); + assertThat(authentication.getAuthorities()).first() + .isInstanceOf(OidcUserAuthority.class) + .hasToString("OIDC_USER"); + + // Ensure shared objects set for OAuth2 Client are not used + ClientRegistrationRepository clientRegistrationRepository = this.spring.getContext() + .getBean(ClientRegistrationRepository.class); + OAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(OAuth2AuthorizedClientRepository.class); + verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -1192,6 +1219,45 @@ public class OAuth2LoginConfigurerTests { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigWithOAuth2Client extends CommonLambdaSecurityFilterChainConfig { + + private final ClientRegistrationRepository clientRegistrationRepository = mock( + ClientRegistrationRepository.class); + + private final OAuth2AuthorizedClientRepository authorizedClientRepository = mock( + OAuth2AuthorizedClientRepository.class); + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> oauth2Login + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) + .authorizedClientRepository(new HttpSessionOAuth2AuthorizedClientRepository()) + ) + .oauth2Client((oauth2Client) -> oauth2Client + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizedClientRepository(this.authorizedClientRepository) + ); + // @formatter:on + return super.configureFilterChain(http); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return this.clientRegistrationRepository; + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return this.authorizedClientRepository; + } + + } + private abstract static class CommonSecurityFilterChainConfig { SecurityFilterChain configureFilterChain(HttpSecurity http) throws Exception {