Default login page supports Iterable<ClientRegistration>

Fixes gh-4596
This commit is contained in:
Joe Grandja 2017-09-29 18:17:14 -04:00
parent 99f06ca58c
commit 66647070ab
5 changed files with 30 additions and 65 deletions

View File

@ -15,7 +15,6 @@
*/
package org.springframework.security.config.annotation.web.configurers.oauth2.client;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
@ -39,15 +38,12 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter.REGISTRATION_ID_URI_VARIABLE_NAME;
@ -75,7 +71,6 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
public OAuth2LoginConfigurer<H> clients(ClientRegistration... clientRegistrations) {
Assert.notEmpty(clientRegistrations, "clientRegistrations cannot be empty");
this.getBuilder().setSharedObject(ClientRegistration[].class, clientRegistrations);
return this.clients(new InMemoryClientRegistrationRepository(Arrays.asList(clientRegistrations)));
}
@ -230,56 +225,24 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
}
private static <H extends HttpSecurityBuilder<H>> ClientRegistrationRepository getDefaultClientRegistrationRepository(H http) {
List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
if (!CollectionUtils.isEmpty(clientRegistrations)) {
return new InMemoryClientRegistrationRepository(clientRegistrations);
}
return http.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class);
}
private static <H extends HttpSecurityBuilder<H>> List<ClientRegistration> getClientRegistrations(H http) {
ClientRegistration[] clientRegistrations = http.getSharedObject(ClientRegistration[].class);
if (clientRegistrations != null) {
return Arrays.asList(clientRegistrations);
}
List<ClientRegistration> clientRegistrationsList = new ArrayList<>();
// Check context for type -> Collection<ClientRegistration>
ResolvableType clientRegistrationsType = ResolvableType.forClassWithGenerics(
Collection.class, ClientRegistration.class);
Map<String, ?> clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
http.getSharedObject(ApplicationContext.class),
clientRegistrationsType.resolve(Collection.class));
clientRegistrationsMap.values().stream()
.filter(col -> Collection.class.isAssignableFrom(col.getClass()))
.filter(col -> ((Collection) col).stream()
.anyMatch(e -> ClientRegistration.class.isAssignableFrom(e.getClass())))
.flatMap(col -> ((Collection) col).stream())
.forEach(e -> clientRegistrationsList.add((ClientRegistration)e));
if (!clientRegistrationsList.isEmpty()) {
return clientRegistrationsList;
}
// Check context for type -> ClientRegistration[]
clientRegistrationsType = ResolvableType.forClass(ClientRegistration[].class);
clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
http.getSharedObject(ApplicationContext.class),
clientRegistrationsType.resolve(ClientRegistration[].class));
clientRegistrationsMap.values().stream()
.flatMap(array -> Arrays.stream((ClientRegistration[])array))
.forEach(clientRegistrationsList::add);
return clientRegistrationsList;
}
private void initDefaultLoginFilter(H http) {
DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class);
if (loginPageGeneratingFilter == null || this.authorizationCodeAuthenticationFilterConfigurer.isCustomLoginPage()) {
return;
}
List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
if (CollectionUtils.isEmpty(clientRegistrations)) {
Iterable<ClientRegistration> clientRegistrations = null;
ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(http);
ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class);
if (type != ResolvableType.NONE) {
if (Stream.of(type.resolveGenerics()).anyMatch(ClientRegistration.class::isAssignableFrom)) {
clientRegistrations = (Iterable<ClientRegistration>) clientRegistrationRepository;
}
}
if (clientRegistrations == null) {
return;
}
@ -298,10 +261,9 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
Map<String, String> oauth2AuthenticationUrlToClientName = clientRegistrations.stream()
.collect(Collectors.toMap(
e -> authorizationRequestBaseUri + "/" + e.getRegistrationId(),
e -> e.getClientName()));
Map<String, String> oauth2AuthenticationUrlToClientName = new HashMap<>();
clientRegistrations.forEach(registration -> oauth2AuthenticationUrlToClientName.put(
authorizationRequestBaseUri + "/" + registration.getRegistrationId(), registration.getClientName()));
loginPageGeneratingFilter.setOauth2LoginEnabled(true);
loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(oauth2AuthenticationUrlToClientName);
loginPageGeneratingFilter.setLoginPageUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginUrl());

View File

@ -19,6 +19,7 @@ import org.springframework.util.Assert;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@ -29,7 +30,7 @@ import java.util.Map;
* @since 5.0
* @see ClientRegistration
*/
public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository {
public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository, Iterable<ClientRegistration> {
private final ClientRegistrationIdentifierStrategy<String> identifierStrategy = new RegistrationIdIdentifierStrategy();
private final Map<String, ClientRegistration> registrations;
@ -54,4 +55,9 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
.findFirst()
.orElse(null);
}
@Override
public Iterator<ClientRegistration> iterator() {
return Collections.unmodifiableCollection(this.registrations.values()).iterator();
}
}

View File

@ -39,7 +39,6 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.client.user.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthenticationProcessingFilter;
import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
@ -58,7 +57,6 @@ import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@ -91,8 +89,6 @@ public class OAuth2LoginApplicationTests {
private WebClient webClient;
@Autowired
private ClientRegistration[] clientRegistrations;
private ClientRegistrationRepository clientRegistrationRepository;
private ClientRegistration googleClientRegistration;
@ -103,7 +99,6 @@ public class OAuth2LoginApplicationTests {
@Before
public void setup() {
this.webClient.getCookieManager().clearCookies();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(Arrays.asList(this.clientRegistrations));
this.googleClientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
this.githubClientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
this.facebookClientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook");

View File

@ -40,6 +40,8 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.core.type.AnnotatedTypeMetadata;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationProperties;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
@ -54,8 +56,8 @@ import java.util.stream.Collectors;
*/
@Configuration
@ConditionalOnWebApplication
@ConditionalOnClass(ClientRegistration.class)
@ConditionalOnMissingBean(ClientRegistration.class)
@ConditionalOnClass(ClientRegistrationRepository.class)
@ConditionalOnMissingBean(ClientRegistrationRepository.class)
@AutoConfigureBefore(SecurityAutoConfiguration.class)
public class ClientRegistrationAutoConfiguration {
private static final String CLIENTS_DEFAULTS_RESOURCE = "META-INF/oauth2-clients-defaults.yml";
@ -72,7 +74,7 @@ public class ClientRegistrationAutoConfiguration {
}
@Bean
public ClientRegistration[] clientRegistrations() {
public ClientRegistrationRepository clientRegistrations() {
MutablePropertySources propertySources = ((ConfigurableEnvironment) this.environment).getPropertySources();
Properties clientsDefaultProperties = this.getClientsDefaultProperties();
if (clientsDefaultProperties != null) {
@ -93,7 +95,7 @@ public class ClientRegistrationAutoConfiguration {
clientRegistrations.add(clientRegistration);
}
return clientRegistrations.toArray(new ClientRegistration[0]);
return new InMemoryClientRegistrationRepository(clientRegistrations);
}
private Properties getClientsDefaultProperties() {

View File

@ -27,7 +27,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfiguration;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
/**
* @author Joe Grandja
@ -36,7 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
@ConditionalOnWebApplication
@ConditionalOnClass(EnableWebSecurity.class)
@ConditionalOnMissingBean(WebSecurityConfiguration.class)
@ConditionalOnBean(ClientRegistration[].class)
@ConditionalOnBean(ClientRegistrationRepository.class)
@AutoConfigureBefore(SecurityAutoConfiguration.class)
@AutoConfigureAfter(ClientRegistrationAutoConfiguration.class)
public class OAuth2LoginAutoConfiguration {