Support configurable JwtDecoder for IdToken verification

Fixes gh-5717
This commit is contained in:
Joe Grandja 2018-12-12 10:28:22 -05:00 committed by Rob Winch
parent be23ab8114
commit 8f4f52edb9
12 changed files with 445 additions and 130 deletions

View File

@ -16,6 +16,7 @@
package org.springframework.security.config.annotation.web.configurers.oauth2.client;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.security.authentication.AuthenticationProvider;
@ -55,6 +56,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
@ -488,6 +490,10 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider =
new OidcAuthorizationCodeAuthenticationProvider(accessTokenResponseClient, oidcUserService);
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = this.getJwtDecoderFactoryBean();
if (jwtDecoderFactory != null) {
oidcAuthorizationCodeAuthenticationProvider.setJwtDecoderFactory(jwtDecoderFactory);
}
if (userAuthoritiesMapper != null) {
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
}
@ -541,6 +547,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
return new AntPathRequestMatcher(loginProcessingUrl);
}
@SuppressWarnings("unchecked")
private JwtDecoderFactory<ClientRegistration> getJwtDecoderFactoryBean() {
ResolvableType type = ResolvableType.forClassWithGenerics(JwtDecoderFactory.class, ClientRegistration.class);
String[] names = this.getBuilder().getSharedObject(ApplicationContext.class).getBeanNamesForType(type);
if (names.length > 1) {
throw new NoUniqueBeanDefinitionException(type, names);
}
if (names.length == 1) {
return (JwtDecoderFactory<ClientRegistration>) this.getBuilder().getSharedObject(ApplicationContext.class).getBean(names[0]);
}
return null;
}
private GrantedAuthoritiesMapper getGrantedAuthoritiesMapper() {
GrantedAuthoritiesMapper grantedAuthoritiesMapper =
this.getBuilder().getSharedObject(GrantedAuthoritiesMapper.class);

View File

@ -16,28 +16,6 @@
package org.springframework.security.config.web.server;
import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.UUID;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.Ordered;
@ -55,6 +33,8 @@ import org.springframework.security.authorization.AuthorizationDecision;
import org.springframework.security.authorization.ReactiveAuthorizationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
@ -80,6 +60,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter;
import org.springframework.security.oauth2.server.resource.authentication.JwtReactiveAuthenticationManager;
import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter;
@ -92,6 +73,7 @@ import org.springframework.security.web.server.MatcherSecurityWebFilterChain;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter;
import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
@ -159,9 +141,27 @@ import org.springframework.web.cors.reactive.DefaultCorsProcessor;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch;
/**
* A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux.
@ -618,7 +618,14 @@ public class ServerHttpSecurity {
boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
if (oidcAuthenticationProviderEnabled) {
OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService());
OidcAuthorizationCodeReactiveAuthenticationManager oidc =
new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService());
ResolvableType type = ResolvableType.forClassWithGenerics(
ReactiveJwtDecoderFactory.class, ClientRegistration.class);
ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = getBeanOrNull(type);
if (jwtDecoderFactory != null) {
oidc.setJwtDecoderFactory(jwtDecoderFactory);
}
result = new DelegatingReactiveAuthenticationManager(oidc, result);
}
return result;

View File

@ -19,11 +19,12 @@ import org.apache.http.HttpHeaders;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.PropertyAccessorFactory;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
@ -50,7 +51,6 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@ -66,6 +66,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
@ -81,6 +82,7 @@ import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -369,8 +371,7 @@ public class OAuth2LoginConfigurerTests {
@Test
public void oidcLogin() throws Exception {
// setup application context
loadConfig(OAuth2LoginConfig.class);
registerJwtDecoder();
loadConfig(OAuth2LoginConfig.class, JwtDecoderFactoryConfig.class);
// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
@ -396,8 +397,7 @@ public class OAuth2LoginConfigurerTests {
@Test
public void oidcLoginCustomWithConfigurer() throws Exception {
// setup application context
loadConfig(OAuth2LoginConfigCustomWithConfigurer.class);
registerJwtDecoder();
loadConfig(OAuth2LoginConfigCustomWithConfigurer.class, JwtDecoderFactoryConfig.class);
// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
@ -423,8 +423,7 @@ public class OAuth2LoginConfigurerTests {
@Test
public void oidcLoginCustomWithBeanRegistration() throws Exception {
// setup application context
loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class);
registerJwtDecoder();
loadConfig(OAuth2LoginConfigCustomWithBeanRegistration.class, JwtDecoderFactoryConfig.class);
// setup authorization request
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
@ -447,6 +446,15 @@ public class OAuth2LoginConfigurerTests {
assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OIDC_USER");
}
@Test
public void oidcLoginCustomWithNoUniqueJwtDecoderFactory() {
assertThatThrownBy(() -> loadConfig(OAuth2LoginConfig.class, NoUniqueJwtDecoderFactoryConfig.class))
.hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class)
.hasMessageContaining("No qualifying bean of type " +
"'org.springframework.security.oauth2.jwt.JwtDecoderFactory<org.springframework.security.oauth2.client.registration.ClientRegistration>' " +
"available: expected single matching bean but found 2: jwtDecoderFactory1,jwtDecoderFactory2");
}
private void loadConfig(Class<?>... configs) {
AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext();
applicationContext.register(configs);
@ -455,25 +463,6 @@ public class OAuth2LoginConfigurerTests {
this.context = applicationContext;
}
private void registerJwtDecoder() {
JwtDecoder decoder = token -> {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "sub123");
claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
claims.put(IdTokenClaimNames.AZP, "clientId");
return new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600),
Collections.singletonMap("header1", "value1"), claims);
};
this.springSecurityFilterChain.getFilters("/login/oauth2/code/google").stream()
.filter(OAuth2LoginAuthenticationFilter.class::isInstance)
.findFirst()
.ifPresent(filter -> PropertyAccessorFactory.forDirectFieldAccess(filter)
.setPropertyValue(
"authenticationManager.providers[2].jwtDecoders['google']",
decoder));
}
private OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(String... scopes) {
return this.createOAuth2AuthorizationRequest(GOOGLE_CLIENT_REGISTRATION, scopes);
}
@ -632,6 +621,42 @@ public class OAuth2LoginConfigurerTests {
}
}
@Configuration
static class JwtDecoderFactoryConfig {
@Bean
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory() {
return clientRegistration -> getJwtDecoder();
}
private static JwtDecoder getJwtDecoder() {
return token -> {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "sub123");
claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
claims.put(IdTokenClaimNames.AZP, "clientId");
return new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600),
Collections.singletonMap("header1", "value1"), claims);
};
}
}
@Configuration
static class NoUniqueJwtDecoderFactoryConfig {
@Bean
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory1() {
return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder();
}
@Bean
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory2() {
return clientRegistration -> JwtDecoderFactoryConfig.getJwtDecoder();
}
}
private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> createOauth2AccessTokenResponseClient() {
return request -> {
Map<String, Object> additionalParameters = new HashMap<>();

View File

@ -16,12 +16,6 @@
package org.springframework.security.config.web.server;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.junit.Rule;
import org.junit.Test;
import org.openqa.selenium.WebDriver;
@ -34,15 +28,29 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.TestOAuth2Users;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.WebFilterChainProxy;
@ -51,9 +59,17 @@ import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
* @author Rob Winch
* @since 5.1
@ -72,6 +88,12 @@ public class OAuth2LoginTests {
.clientSecret("secret")
.build();
private static ClientRegistration google = CommonOAuth2Provider.GOOGLE
.getBuilder("google")
.clientId("client")
.clientSecret("secret")
.build();
@Test
public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() {
this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
@ -97,11 +119,6 @@ public class OAuth2LoginTests {
static class OAuth2LoginWithMulitpleClientRegistrations {
@Bean
InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
ClientRegistration google = CommonOAuth2Provider.GOOGLE
.getBuilder("google")
.clientId("client")
.clientSecret("secret")
.build();
return new InMemoryReactiveClientRegistrationRepository(github, google);
}
}
@ -182,6 +199,107 @@ public class OAuth2LoginTests {
}
}
@Test
public void oauth2LoginWhenCustomJwtDecoderFactoryThenUsed() {
this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class,
OAuth2LoginWithJwtDecoderFactoryBeanConfig.class).autowire();
WebTestClient webTestClient = WebTestClientBuilder
.bindToWebFilters(this.springSecurity)
.build();
OAuth2LoginWithJwtDecoderFactoryBeanConfig config = this.spring.getContext()
.getBean(OAuth2LoginWithJwtDecoderFactoryBeanConfig.class);
OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success();
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid");
OAuth2AuthorizationCodeAuthenticationToken token = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken);
ServerAuthenticationConverter converter = config.authenticationConverter;
when(converter.convert(any())).thenReturn(Mono.just(token));
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
.tokenType(accessToken.getTokenType())
.scopes(accessToken.getScopes())
.additionalParameters(additionalParameters)
.build();
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> tokenResponseClient = config.tokenResponseClient;
when(tokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
OidcUser user = TestOidcUsers.create();
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = config.userService;
when(userService.loadUser(any())).thenReturn(Mono.just(user));
webTestClient.get()
.uri("/login/oauth2/code/google")
.exchange()
.expectStatus().is3xxRedirection();
verify(config.jwtDecoderFactory).createDecoder(any());
}
@Configuration
static class OAuth2LoginWithJwtDecoderFactoryBeanConfig {
ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> tokenResponseClient =
mock(ReactiveOAuth2AccessTokenResponseClient.class);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = spy(new JwtDecoderFactory());
@Bean
public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
// @formatter:off
http
.authorizeExchange()
.anyExchange().authenticated()
.and()
.oauth2Login()
.authenticationConverter(authenticationConverter)
.authenticationManager(authenticationManager());
return http.build();
// @formatter:on
}
private ReactiveAuthenticationManager authenticationManager() {
OidcAuthorizationCodeReactiveAuthenticationManager oidc =
new OidcAuthorizationCodeReactiveAuthenticationManager(tokenResponseClient, userService);
oidc.setJwtDecoderFactory(jwtDecoderFactory());
return oidc;
}
@Bean
public ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory() {
return jwtDecoderFactory;
}
private static class JwtDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
@Override
public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
return getJwtDecoder();
}
private ReactiveJwtDecoder getJwtDecoder() {
return token -> {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "subject");
claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer");
claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client"));
claims.put(IdTokenClaimNames.AZP, "client");
Jwt jwt = new Jwt("id-token", Instant.now(), Instant.now().plusSeconds(3600),
Collections.singletonMap("header1", "value1"), claims);
return Mono.just(jwt);
};
}
}
}
static class GitHubWebFilter implements WebFilter {
@Override

View File

@ -15,10 +15,6 @@
*/
package org.springframework.security.oauth2.client.oidc.authentication;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
@ -43,10 +39,15 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import static org.springframework.security.oauth2.jwt.JwtProcessors.withJwkSetUri;
/**
@ -80,7 +81,7 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier";
private final OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private final OAuth2UserService<OidcUserRequest, OidcUser> userService;
private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new DefaultJwtDecoderFactory();
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
/**
@ -174,6 +175,18 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
return authenticationResult;
}
/**
* Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature verification.
* The factory returns a {@link JwtDecoder} associated to the provided {@link ClientRegistration}.
*
* @since 5.2
* @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature verification
*/
public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}
/**
* Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OidcUser#getAuthorities()}}
* to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}.
@ -191,30 +204,32 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
}
private OidcIdToken createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) {
JwtDecoder jwtDecoder = getJwtDecoder(clientRegistration);
Jwt jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(
OidcParameterNames.ID_TOKEN));
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Jwt jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());
OidcTokenValidator.validateIdToken(idToken, clientRegistration);
return idToken;
}
private JwtDecoder getJwtDecoder(ClientRegistration clientRegistration) {
JwtDecoder jwtDecoder = this.jwtDecoders.get(clientRegistration.getRegistrationId());
if (jwtDecoder == null) {
private static class DefaultJwtDecoderFactory implements JwtDecoderFactory<ClientRegistration> {
private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
@Override
public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> {
if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() + "'. Check to ensure you have configured the JwkSet URI.",
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the JwkSet URI.",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
jwtDecoder = new NimbusJwtDecoder(withJwkSetUri(jwkSetUri).build());
this.jwtDecoders.put(clientRegistration.getRegistrationId(), jwtDecoder);
return new NimbusJwtDecoder(withJwkSetUri(jwkSetUri).build());
});
}
return jwtDecoder;
}
}

View File

@ -39,6 +39,7 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Mono;
@ -46,7 +47,6 @@ import reactor.core.publisher.Mono;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
/**
* An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login,
@ -86,7 +86,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
private Function<ClientRegistration, ReactiveJwtDecoder> decoderFactory = new DefaultDecoderFactory();
private ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new DefaultJwtDecoderFactory();
public OidcAuthorizationCodeReactiveAuthenticationManager(
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient,
@ -143,13 +143,15 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
}
/**
* Provides a way to customize the {@link ReactiveJwtDecoder} given a {@link ClientRegistration}
* @param decoderFactory the {@link Function} used to create {@link ReactiveJwtDecoder} instance. Cannot be null.
* Sets the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature verification.
* The factory returns a {@link ReactiveJwtDecoder} associated to the provided {@link ClientRegistration}.
*
* @since 5.2
* @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} used for {@link OidcIdToken} signature verification
*/
void setDecoderFactory(
Function<ClientRegistration, ReactiveJwtDecoder> decoderFactory) {
Assert.notNull(decoderFactory, "decoderFactory cannot be null");
this.decoderFactory = decoderFactory;
public final void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}
private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
@ -183,33 +185,31 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
}
private Mono<OidcIdToken> createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) {
ReactiveJwtDecoder jwtDecoder = this.decoderFactory.apply(clientRegistration);
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
String rawIdToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
return jwtDecoder.decode(rawIdToken)
.map(jwt -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()))
.doOnNext(idToken -> OidcTokenValidator.validateIdToken(idToken, clientRegistration));
}
private static class DefaultDecoderFactory implements Function<ClientRegistration, ReactiveJwtDecoder> {
private static class DefaultJwtDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
@Override
public ReactiveJwtDecoder apply(ClientRegistration clientRegistration) {
ReactiveJwtDecoder jwtDecoder = this.jwtDecoders.get(clientRegistration.getRegistrationId());
if (jwtDecoder == null) {
public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> {
if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() + "'. Check to ensure you have configured the JwkSet URI.",
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the JwkSet URI.",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
jwtDecoder = new NimbusReactiveJwtDecoder(clientRegistration.getProviderDetails().getJwkSetUri());
this.jwtDecoders.put(clientRegistration.getRegistrationId(), jwtDecoder);
}
return jwtDecoder;
return new NimbusReactiveJwtDecoder(clientRegistration.getProviderDetails().getJwkSetUri());
});
}
}
}

View File

@ -15,22 +15,12 @@
*/
package org.springframework.security.oauth2.client.oidc.authentication;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
@ -52,13 +42,19 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.test.util.ReflectionTestUtils;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
@ -112,6 +108,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null);
}
@Test
public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.authenticationProvider.setJwtDecoderFactory(null);
}
@Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
@ -428,8 +430,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
JwtDecoder jwtDecoder = mock(JwtDecoder.class);
when(jwtDecoder.decode(anyString())).thenReturn(idToken);
ReflectionTestUtils.setField(this.authenticationProvider,
"jwtDecoders", Collections.singletonMap("registration-id", jwtDecoder));
this.authenticationProvider.setJwtDecoderFactory(registration -> jwtDecoder);
}
private OAuth2AccessTokenResponse accessTokenSuccessResponse() {

View File

@ -53,9 +53,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@ -105,6 +103,12 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setJwtDecoderFactoryWhenNullThenIllegalArgumentException() {
assertThatThrownBy(() -> this.manager.setJwtDecoderFactory(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void authenticateWhenNoSubscriptionThenDoesNothing() {
// we didn't do anything because it should cause a ClassCastException (as verified below)
@ -157,7 +161,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
when(this.userService.loadUser(any())).thenReturn(Mono.empty());
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setDecoderFactory(c -> this.jwtDecoder);
this.manager.setJwtDecoderFactory(c -> this.jwtDecoder);
assertThat(this.manager.authenticate(loginToken()).block()).isNull();
}
@ -180,7 +184,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
when(this.userService.loadUser(any())).thenReturn(Mono.just(user));
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setDecoderFactory(c -> this.jwtDecoder);
this.manager.setJwtDecoderFactory(c -> this.jwtDecoder);
OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block();
@ -209,7 +213,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
when(this.userService.loadUser(any())).thenReturn(Mono.just(user));
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setDecoderFactory(c -> this.jwtDecoder);
this.manager.setJwtDecoderFactory(c -> this.jwtDecoder);
OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block();
@ -245,7 +249,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setDecoderFactory(c -> this.jwtDecoder);
this.manager.setJwtDecoderFactory(c -> this.jwtDecoder);
this.manager.authenticate(loginToken()).block();

View File

@ -32,6 +32,7 @@ public class TestOAuth2AuthorizationRequests {
return OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/login/oauth/authorize")
.clientId(clientId)
.scope("openid")
.redirectUri("https://example.com/authorize/oauth2/code/registration-id")
.state("state")
.additionalParameters(additionalParameters);

View File

@ -0,0 +1,47 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.core.oidc.user;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* @author Joe Grandja
*/
public class TestOidcUsers {
public static DefaultOidcUser create() {
List<GrantedAuthority> roles = AuthorityUtils.createAuthorityList("ROLE_USER");
return new DefaultOidcUser(roles, idToken());
}
private static OidcIdToken idToken() {
Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.SUB, "subject");
claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer");
claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client"));
claims.put(IdTokenClaimNames.AZP, "client");
return new OidcIdToken("id-token", Instant.now(), Instant.now().plusSeconds(3600), claims);
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
/**
* A factory for {@link JwtDecoder}(s).
* This factory should be supplied with a type that provides
* contextual information used to create a specific {@code JwtDecoder}.
*
* @author Joe Grandja
* @since 5.2
* @see JwtDecoder
*
* @param <C> The type that provides contextual information used to create a specific {@code JwtDecoder}.
*/
public interface JwtDecoderFactory<C> {
/**
* Creates a {@code JwtDecoder} using the supplied "contextual" type.
*
* @param context the type that provides contextual information
* @return a {@link JwtDecoder}
*/
JwtDecoder createDecoder(C context);
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jwt;
/**
* A factory for {@link ReactiveJwtDecoder}(s).
* This factory should be supplied with a type that provides
* contextual information used to create a specific {@code ReactiveJwtDecoder}.
*
* @author Joe Grandja
* @since 5.2
* @see ReactiveJwtDecoder
*
* @param <C> The type that provides contextual information used to create a specific {@code ReactiveJwtDecoder}.
*/
public interface ReactiveJwtDecoderFactory<C> {
/**
* Creates a {@code ReactiveJwtDecoder} using the supplied "contextual" type.
*
* @param context the type that provides contextual information
* @return a {@link ReactiveJwtDecoder}
*/
ReactiveJwtDecoder createDecoder(C context);
}