diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index f340063000..810ddf010e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -15,7 +15,6 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.client; -import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.context.ApplicationContext; import org.springframework.core.ResolvableType; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; @@ -39,15 +38,12 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestVariablesExtractor; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import java.net.URI; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; -import java.util.List; +import java.util.HashMap; import java.util.Map; -import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter.REGISTRATION_ID_URI_VARIABLE_NAME; @@ -75,7 +71,6 @@ public final class OAuth2LoginConfigurer> exten public OAuth2LoginConfigurer clients(ClientRegistration... clientRegistrations) { Assert.notEmpty(clientRegistrations, "clientRegistrations cannot be empty"); - this.getBuilder().setSharedObject(ClientRegistration[].class, clientRegistrations); return this.clients(new InMemoryClientRegistrationRepository(Arrays.asList(clientRegistrations))); } @@ -230,56 +225,24 @@ public final class OAuth2LoginConfigurer> exten } private static > ClientRegistrationRepository getDefaultClientRegistrationRepository(H http) { - List clientRegistrations = getClientRegistrations(http); - if (!CollectionUtils.isEmpty(clientRegistrations)) { - return new InMemoryClientRegistrationRepository(clientRegistrations); - } return http.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class); } - private static > List getClientRegistrations(H http) { - ClientRegistration[] clientRegistrations = http.getSharedObject(ClientRegistration[].class); - if (clientRegistrations != null) { - return Arrays.asList(clientRegistrations); - } - - List clientRegistrationsList = new ArrayList<>(); - - // Check context for type -> Collection - ResolvableType clientRegistrationsType = ResolvableType.forClassWithGenerics( - Collection.class, ClientRegistration.class); - Map clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( - http.getSharedObject(ApplicationContext.class), - clientRegistrationsType.resolve(Collection.class)); - clientRegistrationsMap.values().stream() - .filter(col -> Collection.class.isAssignableFrom(col.getClass())) - .filter(col -> ((Collection) col).stream() - .anyMatch(e -> ClientRegistration.class.isAssignableFrom(e.getClass()))) - .flatMap(col -> ((Collection) col).stream()) - .forEach(e -> clientRegistrationsList.add((ClientRegistration)e)); - if (!clientRegistrationsList.isEmpty()) { - return clientRegistrationsList; - } - - // Check context for type -> ClientRegistration[] - clientRegistrationsType = ResolvableType.forClass(ClientRegistration[].class); - clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors( - http.getSharedObject(ApplicationContext.class), - clientRegistrationsType.resolve(ClientRegistration[].class)); - clientRegistrationsMap.values().stream() - .flatMap(array -> Arrays.stream((ClientRegistration[])array)) - .forEach(clientRegistrationsList::add); - - return clientRegistrationsList; - } - private void initDefaultLoginFilter(H http) { DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class); if (loginPageGeneratingFilter == null || this.authorizationCodeAuthenticationFilterConfigurer.isCustomLoginPage()) { return; } - List clientRegistrations = getClientRegistrations(http); - if (CollectionUtils.isEmpty(clientRegistrations)) { + + Iterable clientRegistrations = null; + ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(http); + ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class); + if (type != ResolvableType.NONE) { + if (Stream.of(type.resolveGenerics()).anyMatch(ClientRegistration.class::isAssignableFrom)) { + clientRegistrations = (Iterable) clientRegistrationRepository; + } + } + if (clientRegistrations == null) { return; } @@ -298,10 +261,9 @@ public final class OAuth2LoginConfigurer> exten authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI; } - Map oauth2AuthenticationUrlToClientName = clientRegistrations.stream() - .collect(Collectors.toMap( - e -> authorizationRequestBaseUri + "/" + e.getRegistrationId(), - e -> e.getClientName())); + Map oauth2AuthenticationUrlToClientName = new HashMap<>(); + clientRegistrations.forEach(registration -> oauth2AuthenticationUrlToClientName.put( + authorizationRequestBaseUri + "/" + registration.getRegistrationId(), registration.getClientName())); loginPageGeneratingFilter.setOauth2LoginEnabled(true); loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(oauth2AuthenticationUrlToClientName); loginPageGeneratingFilter.setLoginPageUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginUrl()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java index b561418e86..a7f319b5a5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java @@ -19,6 +19,7 @@ import org.springframework.util.Assert; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -29,7 +30,7 @@ import java.util.Map; * @since 5.0 * @see ClientRegistration */ -public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository { +public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository, Iterable { private final ClientRegistrationIdentifierStrategy identifierStrategy = new RegistrationIdIdentifierStrategy(); private final Map registrations; @@ -54,4 +55,9 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr .findFirst() .orElse(null); } + + @Override + public Iterator iterator() { + return Collections.unmodifiableCollection(this.registrations.values()).iterator(); + } } diff --git a/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java b/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java index ad85bbac0e..7cc2ed499c 100644 --- a/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java +++ b/samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java @@ -39,7 +39,6 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.user.OAuth2UserService; import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthenticationProcessingFilter; import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter; @@ -58,7 +57,6 @@ import org.springframework.web.util.UriComponentsBuilder; import java.net.URI; import java.net.URL; import java.net.URLDecoder; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -91,8 +89,6 @@ public class OAuth2LoginApplicationTests { private WebClient webClient; @Autowired - private ClientRegistration[] clientRegistrations; - private ClientRegistrationRepository clientRegistrationRepository; private ClientRegistration googleClientRegistration; @@ -103,7 +99,6 @@ public class OAuth2LoginApplicationTests { @Before public void setup() { this.webClient.getCookieManager().clearCookies(); - this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(Arrays.asList(this.clientRegistrations)); this.googleClientRegistration = this.clientRegistrationRepository.findByRegistrationId("google"); this.githubClientRegistration = this.clientRegistrationRepository.findByRegistrationId("github"); this.facebookClientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook"); diff --git a/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/ClientRegistrationAutoConfiguration.java b/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/ClientRegistrationAutoConfiguration.java index ace33e0813..6f92c63859 100644 --- a/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/ClientRegistrationAutoConfiguration.java +++ b/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/ClientRegistrationAutoConfiguration.java @@ -40,6 +40,8 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.type.AnnotatedTypeMetadata; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationProperties; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.util.CollectionUtils; import java.util.ArrayList; @@ -54,8 +56,8 @@ import java.util.stream.Collectors; */ @Configuration @ConditionalOnWebApplication -@ConditionalOnClass(ClientRegistration.class) -@ConditionalOnMissingBean(ClientRegistration.class) +@ConditionalOnClass(ClientRegistrationRepository.class) +@ConditionalOnMissingBean(ClientRegistrationRepository.class) @AutoConfigureBefore(SecurityAutoConfiguration.class) public class ClientRegistrationAutoConfiguration { private static final String CLIENTS_DEFAULTS_RESOURCE = "META-INF/oauth2-clients-defaults.yml"; @@ -72,7 +74,7 @@ public class ClientRegistrationAutoConfiguration { } @Bean - public ClientRegistration[] clientRegistrations() { + public ClientRegistrationRepository clientRegistrations() { MutablePropertySources propertySources = ((ConfigurableEnvironment) this.environment).getPropertySources(); Properties clientsDefaultProperties = this.getClientsDefaultProperties(); if (clientsDefaultProperties != null) { @@ -93,7 +95,7 @@ public class ClientRegistrationAutoConfiguration { clientRegistrations.add(clientRegistration); } - return clientRegistrations.toArray(new ClientRegistration[0]); + return new InMemoryClientRegistrationRepository(clientRegistrations); } private Properties getClientsDefaultProperties() { diff --git a/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/OAuth2LoginAutoConfiguration.java b/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/OAuth2LoginAutoConfiguration.java index ba14f03a53..02d75cdc7e 100644 --- a/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/OAuth2LoginAutoConfiguration.java +++ b/samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/OAuth2LoginAutoConfiguration.java @@ -27,7 +27,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfiguration; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; -import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; /** * @author Joe Grandja @@ -36,7 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio @ConditionalOnWebApplication @ConditionalOnClass(EnableWebSecurity.class) @ConditionalOnMissingBean(WebSecurityConfiguration.class) -@ConditionalOnBean(ClientRegistration[].class) +@ConditionalOnBean(ClientRegistrationRepository.class) @AutoConfigureBefore(SecurityAutoConfiguration.class) @AutoConfigureAfter(ClientRegistrationAutoConfiguration.class) public class OAuth2LoginAutoConfiguration {