Remove ClientRegistrationRepository.getRegistrations()

Fixes gh-4582
This commit is contained in:
Joe Grandja 2017-09-28 06:08:56 -04:00
parent 3217582805
commit 8448a54678
10 changed files with 121 additions and 118 deletions

View File

@ -18,13 +18,10 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.cl
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractAuthenticationFilterConfigurer;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.jwt.JwtDecoder;
import org.springframework.security.jwt.nimbus.NimbusJwtDecoderJwkSupport;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.jwt.DefaultProviderJwtDecoderRegistry;
import org.springframework.security.oauth2.client.authentication.jwt.ProviderJwtDecoderRegistry;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.authentication.jwt.JwtDecoderRegistry;
import org.springframework.security.oauth2.client.authentication.jwt.nimbus.NimbusJwtDecoderRegistry;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.token.InMemoryAccessTokenRepository;
import org.springframework.security.oauth2.client.token.SecurityTokenRepository;
@ -36,19 +33,13 @@ import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthentic
import org.springframework.security.oauth2.client.web.AuthorizationGrantTokenExchanger;
import org.springframework.security.oauth2.client.web.nimbus.NimbusAuthorizationCodeTokenExchanger;
import org.springframework.security.oauth2.core.AccessToken;
import org.springframework.security.oauth2.core.provider.DefaultProviderMetadata;
import org.springframework.security.oauth2.core.provider.ProviderMetadata;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.oidc.client.user.OidcUserService;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -63,6 +54,7 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
private R authorizationResponseMatcher;
private AuthorizationGrantTokenExchanger<AuthorizationCodeAuthenticationToken> authorizationCodeTokenExchanger;
private SecurityTokenRepository<AccessToken> accessTokenRepository;
private JwtDecoderRegistry jwtDecoderRegistry;
private OAuth2UserService userInfoService;
private Map<URI, Class<? extends OAuth2User>> customUserTypes = new HashMap<>();
private GrantedAuthoritiesMapper userAuthoritiesMapper;
@ -91,6 +83,12 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
return this;
}
AuthorizationCodeAuthenticationFilterConfigurer<H, R> jwtDecoderRegistry(JwtDecoderRegistry jwtDecoderRegistry) {
Assert.notNull(jwtDecoderRegistry, "jwtDecoderRegistry cannot be null");
this.jwtDecoderRegistry = jwtDecoderRegistry;
return this;
}
AuthorizationCodeAuthenticationFilterConfigurer<H, R> userInfoService(OAuth2UserService userInfoService) {
Assert.notNull(userInfoService, "userInfoService cannot be null");
this.userInfoService = userInfoService;
@ -112,7 +110,6 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
AuthorizationCodeAuthenticationFilterConfigurer<H, R> clientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notEmpty(clientRegistrationRepository.getRegistrations(), "clientRegistrationRepository cannot be empty");
this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository);
return this;
}
@ -129,7 +126,7 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
public void init(H http) throws Exception {
AuthorizationCodeAuthenticationProvider authenticationProvider = new AuthorizationCodeAuthenticationProvider(
this.getAuthorizationCodeTokenExchanger(), this.getAccessTokenRepository(),
this.getProviderJwtDecoderRegistry(), this.getUserInfoService());
this.getJwtDecoderRegistry(), this.getUserInfoService());
if (this.userAuthoritiesMapper != null) {
authenticationProvider.setAuthoritiesMapper(this.userAuthoritiesMapper);
}
@ -168,48 +165,18 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
return this.accessTokenRepository;
}
private ProviderJwtDecoderRegistry getProviderJwtDecoderRegistry() {
Map<ProviderMetadata, JwtDecoder> jwtDecoders = new HashMap<>();
ClientRegistrationRepository clientRegistrationRepository = OAuth2LoginConfigurer.getClientRegistrationRepository(this.getBuilder());
clientRegistrationRepository.getRegistrations().forEach(registration -> {
ClientRegistration.ProviderDetails providerDetails = registration.getProviderDetails();
if (StringUtils.hasText(providerDetails.getJwkSetUri())) {
DefaultProviderMetadata providerMetadata = new DefaultProviderMetadata();
// Default the Issuer to the host of the Authorization Endpoint
providerMetadata.setIssuer(this.toURL(
UriComponentsBuilder
.fromHttpUrl(providerDetails.getAuthorizationUri())
.replacePath(null)
.toUriString()
));
providerMetadata.setAuthorizationEndpoint(this.toURL(providerDetails.getAuthorizationUri()));
providerMetadata.setTokenEndpoint(this.toURL(providerDetails.getTokenUri()));
providerMetadata.setUserInfoEndpoint(this.toURL(providerDetails.getUserInfoEndpoint().getUri()));
providerMetadata.setJwkSetUri(this.toURL(providerDetails.getJwkSetUri()));
NimbusJwtDecoderJwkSupport nimbusJwtDecoderJwkSupport =
new NimbusJwtDecoderJwkSupport(providerDetails.getJwkSetUri());
jwtDecoders.put(providerMetadata, nimbusJwtDecoderJwkSupport);
}
});
return new DefaultProviderJwtDecoderRegistry(jwtDecoders);
}
private boolean isOidcClientRegistered() {
ClientRegistrationRepository clientRegistrationRepository = OAuth2LoginConfigurer.getClientRegistrationRepository(this.getBuilder());
return clientRegistrationRepository.getRegistrations()
.stream()
.anyMatch(registration ->
registration.getScope().stream().anyMatch(scope -> scope.equalsIgnoreCase("openid")));
private JwtDecoderRegistry getJwtDecoderRegistry() {
if (this.jwtDecoderRegistry == null) {
this.jwtDecoderRegistry = new NimbusJwtDecoderRegistry();
}
return this.jwtDecoderRegistry;
}
private OAuth2UserService getUserInfoService() {
if (this.userInfoService == null) {
List<OAuth2UserService> oauth2UserServices = new ArrayList<>();
oauth2UserServices.add(new DefaultOAuth2UserService());
if (this.isOidcClientRegistered()) {
oauth2UserServices.add(new OidcUserService());
}
oauth2UserServices.add(new OidcUserService());
if (!this.customUserTypes.isEmpty()) {
oauth2UserServices.add(new CustomUserTypesOAuth2UserService(this.customUserTypes));
}
@ -217,15 +184,4 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
}
return this.userInfoService;
}
private URL toURL(String urlStr) {
if (!StringUtils.hasText(urlStr)) {
return null;
}
try {
return new URL(urlStr);
} catch (MalformedURLException ex) {
throw new IllegalArgumentException("Failed to convert '" + urlStr + "' to a URL: " + ex.getMessage(), ex);
}
}
}

View File

@ -17,10 +17,10 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.cl
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.AuthorizationRequestUriBuilder;
import org.springframework.security.oauth2.client.web.DefaultAuthorizationRequestUriBuilder;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
import org.springframework.util.Assert;
@ -48,7 +48,6 @@ final class AuthorizationCodeRequestRedirectFilterConfigurer<H extends HttpSecur
AuthorizationCodeRequestRedirectFilterConfigurer<H, R> clientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notEmpty(clientRegistrationRepository.getRegistrations(), "clientRegistrationRepository cannot be empty");
this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository);
return this;
}

View File

@ -15,7 +15,9 @@
*/
package org.springframework.security.config.annotation.web.configurers.oauth2.client;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
@ -40,7 +42,10 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -70,12 +75,12 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
public OAuth2LoginConfigurer<H> clients(ClientRegistration... clientRegistrations) {
Assert.notEmpty(clientRegistrations, "clientRegistrations cannot be empty");
this.getBuilder().setSharedObject(ClientRegistration[].class, clientRegistrations);
return this.clients(new InMemoryClientRegistrationRepository(Arrays.asList(clientRegistrations)));
}
public OAuth2LoginConfigurer<H> clients(ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notEmpty(clientRegistrationRepository.getRegistrations(), "clientRegistrationRepository cannot be empty");
this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository);
return this;
}
@ -225,38 +230,81 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
}
private static <H extends HttpSecurityBuilder<H>> ClientRegistrationRepository getDefaultClientRegistrationRepository(H http) {
List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
if (!CollectionUtils.isEmpty(clientRegistrations)) {
return new InMemoryClientRegistrationRepository(clientRegistrations);
}
return http.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class);
}
private static <H extends HttpSecurityBuilder<H>> List<ClientRegistration> getClientRegistrations(H http) {
ClientRegistration[] clientRegistrations = http.getSharedObject(ClientRegistration[].class);
if (clientRegistrations != null) {
return Arrays.asList(clientRegistrations);
}
List<ClientRegistration> clientRegistrationsList = new ArrayList<>();
// Check context for type -> Collection<ClientRegistration>
ResolvableType clientRegistrationsType = ResolvableType.forClassWithGenerics(
Collection.class, ClientRegistration.class);
Map<String, ?> clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
http.getSharedObject(ApplicationContext.class),
clientRegistrationsType.resolve(Collection.class));
clientRegistrationsMap.values().stream()
.filter(col -> Collection.class.isAssignableFrom(col.getClass()))
.filter(col -> ((Collection) col).stream()
.anyMatch(e -> ClientRegistration.class.isAssignableFrom(e.getClass())))
.flatMap(col -> ((Collection) col).stream())
.forEach(e -> clientRegistrationsList.add((ClientRegistration)e));
if (!clientRegistrationsList.isEmpty()) {
return clientRegistrationsList;
}
// Check context for type -> ClientRegistration[]
clientRegistrationsType = ResolvableType.forClass(ClientRegistration[].class);
clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
http.getSharedObject(ApplicationContext.class),
clientRegistrationsType.resolve(ClientRegistration[].class));
clientRegistrationsMap.values().stream()
.flatMap(array -> Arrays.stream((ClientRegistration[])array))
.forEach(clientRegistrationsList::add);
return clientRegistrationsList;
}
private void initDefaultLoginFilter(H http) {
DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class);
if (loginPageGeneratingFilter != null && !this.authorizationCodeAuthenticationFilterConfigurer.isCustomLoginPage()) {
ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(this.getBuilder());
if (!CollectionUtils.isEmpty(clientRegistrationRepository.getRegistrations())) {
String authorizationRequestBaseUri;
RequestMatcher authorizationRequestMatcher = OAuth2LoginConfigurer.this.authorizationCodeRequestRedirectFilterConfigurer.getAuthorizationRequestMatcher();
if (authorizationRequestMatcher != null && AntPathRequestMatcher.class.isAssignableFrom(authorizationRequestMatcher.getClass())) {
String authorizationRequestPattern = ((AntPathRequestMatcher)authorizationRequestMatcher).getPattern();
String registrationIdTemplateVariable = "{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}";
if (authorizationRequestPattern.endsWith(registrationIdTemplateVariable)) {
authorizationRequestBaseUri = authorizationRequestPattern.substring(
0, authorizationRequestPattern.length() - registrationIdTemplateVariable.length() - 1);
} else {
authorizationRequestBaseUri = authorizationRequestPattern;
}
} else {
authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
Map<String, String> oauth2AuthenticationUrlToClientName = clientRegistrationRepository.getRegistrations().stream()
.collect(Collectors.toMap(
e -> authorizationRequestBaseUri + "/" + e.getRegistrationId(),
e -> e.getClientName()));
loginPageGeneratingFilter.setOauth2LoginEnabled(true);
loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(oauth2AuthenticationUrlToClientName);
loginPageGeneratingFilter.setLoginPageUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginUrl());
loginPageGeneratingFilter.setFailureUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginFailureUrl());
}
if (loginPageGeneratingFilter == null || this.authorizationCodeAuthenticationFilterConfigurer.isCustomLoginPage()) {
return;
}
List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
if (CollectionUtils.isEmpty(clientRegistrations)) {
return;
}
String authorizationRequestBaseUri;
RequestMatcher authorizationRequestMatcher = OAuth2LoginConfigurer.this.authorizationCodeRequestRedirectFilterConfigurer.getAuthorizationRequestMatcher();
if (authorizationRequestMatcher != null && AntPathRequestMatcher.class.isAssignableFrom(authorizationRequestMatcher.getClass())) {
String authorizationRequestPattern = ((AntPathRequestMatcher)authorizationRequestMatcher).getPattern();
String registrationIdTemplateVariable = "{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}";
if (authorizationRequestPattern.endsWith(registrationIdTemplateVariable)) {
authorizationRequestBaseUri = authorizationRequestPattern.substring(
0, authorizationRequestPattern.length() - registrationIdTemplateVariable.length() - 1);
} else {
authorizationRequestBaseUri = authorizationRequestPattern;
}
} else {
authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
Map<String, String> oauth2AuthenticationUrlToClientName = clientRegistrations.stream()
.collect(Collectors.toMap(
e -> authorizationRequestBaseUri + "/" + e.getRegistrationId(),
e -> e.getClientName()));
loginPageGeneratingFilter.setOauth2LoginEnabled(true);
loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(oauth2AuthenticationUrlToClientName);
loginPageGeneratingFilter.setLoginPageUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginUrl());
loginPageGeneratingFilter.setFailureUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginFailureUrl());
}
}

View File

@ -24,7 +24,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
import org.springframework.security.core.authority.mapping.NullAuthoritiesMapper;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtDecoder;
import org.springframework.security.oauth2.client.authentication.jwt.ProviderJwtDecoderRegistry;
import org.springframework.security.oauth2.client.authentication.jwt.JwtDecoderRegistry;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.token.SecurityTokenRepository;
import org.springframework.security.oauth2.client.user.OAuth2UserService;
@ -89,23 +89,23 @@ import java.util.Collection;
public class AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
private final AuthorizationGrantTokenExchanger<AuthorizationCodeAuthenticationToken> authorizationCodeTokenExchanger;
private final SecurityTokenRepository<AccessToken> accessTokenRepository;
private final ProviderJwtDecoderRegistry providerJwtDecoderRegistry;
private final JwtDecoderRegistry jwtDecoderRegistry;
private final OAuth2UserService userInfoService;
private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();
public AuthorizationCodeAuthenticationProvider(
AuthorizationGrantTokenExchanger<AuthorizationCodeAuthenticationToken> authorizationCodeTokenExchanger,
SecurityTokenRepository<AccessToken> accessTokenRepository,
ProviderJwtDecoderRegistry providerJwtDecoderRegistry,
JwtDecoderRegistry jwtDecoderRegistry,
OAuth2UserService userInfoService) {
Assert.notNull(authorizationCodeTokenExchanger, "authorizationCodeTokenExchanger cannot be null");
Assert.notNull(accessTokenRepository, "accessTokenRepository cannot be null");
Assert.notNull(providerJwtDecoderRegistry, "providerJwtDecoderRegistry cannot be null");
Assert.notNull(jwtDecoderRegistry, "jwtDecoderRegistry cannot be null");
Assert.notNull(userInfoService, "userInfoService cannot be null");
this.authorizationCodeTokenExchanger = authorizationCodeTokenExchanger;
this.accessTokenRepository = accessTokenRepository;
this.providerJwtDecoderRegistry = providerJwtDecoderRegistry;
this.jwtDecoderRegistry = jwtDecoderRegistry;
this.userInfoService = userInfoService;
}
@ -124,9 +124,9 @@ public class AuthorizationCodeAuthenticationProvider implements AuthenticationPr
IdToken idToken = null;
if (tokenResponse.getAdditionalParameters().containsKey(OidcParameter.ID_TOKEN)) {
JwtDecoder jwtDecoder = this.providerJwtDecoderRegistry.getJwtDecoder(clientRegistration.getProviderDetails().getJwkSetUri());
JwtDecoder jwtDecoder = this.jwtDecoderRegistry.getJwtDecoder(clientRegistration);
if (jwtDecoder == null) {
throw new IllegalArgumentException("Unable to find a registered JwtDecoder for the provider '" + clientRegistration.getProviderDetails().getTokenUri() +
throw new IllegalArgumentException("Unable to find a registered JwtDecoder for Client Registration: '" + clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the JwkSet URI property.");
}
Jwt jwt = jwtDecoder.decode((String)tokenResponse.getAdditionalParameters().get(OidcParameter.ID_TOKEN));

View File

@ -37,6 +37,4 @@ public interface ClientRegistrationRepository {
ClientRegistration findByRegistrationId(String registrationId);
List<ClientRegistration> getRegistrations();
}

View File

@ -17,7 +17,6 @@ package org.springframework.security.oauth2.client.registration;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -64,9 +63,4 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
.findFirst()
.orElse(null);
}
@Override
public List<ClientRegistration> getRegistrations() {
return new ArrayList<>(this.registrations.values());
}
}

View File

@ -192,7 +192,6 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
public final void setClientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notEmpty(clientRegistrationRepository.getRegistrations(), "clientRegistrationRepository cannot be empty");
this.clientRegistrationRepository = clientRegistrationRepository;
}

View File

@ -36,13 +36,14 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthenticationProcessingFilter;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.AuthorizationGrantTokenExchanger;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.client.user.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthenticationProcessingFilter;
import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.AuthorizationGrantTokenExchanger;
import org.springframework.security.oauth2.core.AccessToken;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
@ -57,7 +58,14 @@ import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
@ -83,6 +91,8 @@ public class OAuth2LoginApplicationTests {
private WebClient webClient;
@Autowired
private ClientRegistration[] clientRegistrations;
private ClientRegistrationRepository clientRegistrationRepository;
private ClientRegistration googleClientRegistration;
@ -93,6 +103,7 @@ public class OAuth2LoginApplicationTests {
@Before
public void setup() {
this.webClient.getCookieManager().clearCookies();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(Arrays.asList(this.clientRegistrations));
this.googleClientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
this.githubClientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
this.facebookClientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook");

View File

@ -40,8 +40,6 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.core.type.AnnotatedTypeMetadata;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationProperties;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
@ -56,8 +54,8 @@ import java.util.stream.Collectors;
*/
@Configuration
@ConditionalOnWebApplication
@ConditionalOnClass(ClientRegistrationRepository.class)
@ConditionalOnMissingBean(ClientRegistrationRepository.class)
@ConditionalOnClass(ClientRegistration.class)
@ConditionalOnMissingBean(ClientRegistration.class)
@AutoConfigureBefore(SecurityAutoConfiguration.class)
public class ClientRegistrationAutoConfiguration {
private static final String CLIENTS_DEFAULTS_RESOURCE = "META-INF/oauth2-clients-defaults.yml";
@ -74,7 +72,7 @@ public class ClientRegistrationAutoConfiguration {
}
@Bean
public ClientRegistrationRepository clientRegistrationRepository() {
public ClientRegistration[] clientRegistrations() {
MutablePropertySources propertySources = ((ConfigurableEnvironment) this.environment).getPropertySources();
Properties clientsDefaultProperties = this.getClientsDefaultProperties();
if (clientsDefaultProperties != null) {
@ -95,7 +93,7 @@ public class ClientRegistrationAutoConfiguration {
clientRegistrations.add(clientRegistration);
}
return new InMemoryClientRegistrationRepository(clientRegistrations);
return clientRegistrations.toArray(new ClientRegistration[0]);
}
private Properties getClientsDefaultProperties() {

View File

@ -27,7 +27,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfiguration;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
/**
* @author Joe Grandja
@ -36,7 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
@ConditionalOnWebApplication
@ConditionalOnClass(EnableWebSecurity.class)
@ConditionalOnMissingBean(WebSecurityConfiguration.class)
@ConditionalOnBean(ClientRegistrationRepository.class)
@ConditionalOnBean(ClientRegistration[].class)
@AutoConfigureBefore(SecurityAutoConfiguration.class)
@AutoConfigureAfter(ClientRegistrationAutoConfiguration.class)
public class OAuth2LoginAutoConfiguration {