OAuth2LoginReactiveAuthenticationManager uses OAuth2AuthorizationCodeReactiveAuthenticationManager

Issue: gh-5620
This commit is contained in:
Rob Winch 2018-08-17 22:05:33 -05:00
parent 8b67154e77
commit d0ebe47cd5
10 changed files with 76 additions and 65 deletions

View File

@ -40,9 +40,9 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2LoginAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
@ -491,7 +491,7 @@ public class ServerHttpSecurity {
AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository); AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository);
authenticationFilter.setRequiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}")); authenticationFilter.setRequiresAuthenticationMatcher(new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"));
authenticationFilter.setServerAuthenticationConverter(new ServerOAuth2LoginAuthenticationTokenConverter(clientRegistrationRepository)); authenticationFilter.setServerAuthenticationConverter(new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository));
RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler(); RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler();

View File

@ -25,6 +25,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch
import org.springframework.util.Assert; import org.springframework.util.Assert;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/** /**
* An {@link AbstractAuthenticationToken} for the OAuth 2.0 Authorization Code Grant. * An {@link AbstractAuthenticationToken} for the OAuth 2.0 Authorization Code Grant.
@ -39,6 +41,7 @@ import java.util.Collections;
*/ */
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken { public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
private Map<String, Object> additionalParameters = new HashMap<>();
private ClientRegistration clientRegistration; private ClientRegistration clientRegistration;
private OAuth2AuthorizationExchange authorizationExchange; private OAuth2AuthorizationExchange authorizationExchange;
private OAuth2AccessToken accessToken; private OAuth2AccessToken accessToken;
@ -86,11 +89,17 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
OAuth2AuthorizationExchange authorizationExchange, OAuth2AuthorizationExchange authorizationExchange,
OAuth2AccessToken accessToken, OAuth2AccessToken accessToken,
@Nullable OAuth2RefreshToken refreshToken) { @Nullable OAuth2RefreshToken refreshToken) {
this(clientRegistration, authorizationExchange, accessToken, refreshToken, Collections.emptyMap());
}
public OAuth2AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration, OAuth2AuthorizationExchange authorizationExchange, OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken,
Map<String, Object> additionalParameters) {
this(clientRegistration, authorizationExchange); this(clientRegistration, authorizationExchange);
Assert.notNull(accessToken, "accessToken cannot be null"); Assert.notNull(accessToken, "accessToken cannot be null");
this.accessToken = accessToken; this.accessToken = accessToken;
this.refreshToken = refreshToken; this.refreshToken = refreshToken;
this.setAuthenticated(true); this.setAuthenticated(true);
this.additionalParameters.putAll(additionalParameters);
} }
@Override @Override
@ -140,4 +149,13 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
public @Nullable OAuth2RefreshToken getRefreshToken() { public @Nullable OAuth2RefreshToken getRefreshToken() {
return this.refreshToken; return this.refreshToken;
} }
/**
* Returns the additional parameters
*
* @return the additional parameters
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
} }

View File

@ -87,7 +87,7 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements
OAuth2AuthorizationExchange exchange = token.getAuthorizationExchange(); OAuth2AuthorizationExchange exchange = token.getAuthorizationExchange();
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
OAuth2RefreshToken refreshToken = accessTokenResponse.getRefreshToken(); OAuth2RefreshToken refreshToken = accessTokenResponse.getRefreshToken();
return new OAuth2AuthorizationCodeAuthenticationToken(registration, exchange, accessToken, refreshToken); return new OAuth2AuthorizationCodeAuthenticationToken(registration, exchange, accessToken, refreshToken, accessTokenResponse.getAdditionalParameters());
}; };
} }
} }

View File

@ -15,26 +15,21 @@
*/ */
package org.springframework.security.oauth2.client.authentication; package org.springframework.security.oauth2.client.authentication;
import java.util.Collection;
import java.util.Map;
import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.Collection;
import java.util.Map;
/** /**
* An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login, * An implementation of an {@link org.springframework.security.authentication.AuthenticationProvider} for OAuth 2.0 Login,
@ -62,7 +57,7 @@ import reactor.core.publisher.Mono;
*/ */
public class OAuth2LoginReactiveAuthenticationManager implements public class OAuth2LoginReactiveAuthenticationManager implements
ReactiveAuthenticationManager { ReactiveAuthenticationManager {
private final ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient; private final ReactiveAuthenticationManager authorizationCodeManager;
private final ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService; private final ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
@ -73,18 +68,18 @@ public class OAuth2LoginReactiveAuthenticationManager implements
ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService) { ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> userService) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
Assert.notNull(userService, "userService cannot be null"); Assert.notNull(userService, "userService cannot be null");
this.accessTokenResponseClient = accessTokenResponseClient; this.authorizationCodeManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(accessTokenResponseClient);
this.userService = userService; this.userService = userService;
} }
@Override @Override
public Mono<Authentication> authenticate(Authentication authentication) { public Mono<Authentication> authenticate(Authentication authentication) {
return Mono.defer(() -> { return Mono.defer(() -> {
OAuth2LoginAuthenticationToken authorizationCodeAuthentication = (OAuth2LoginAuthenticationToken) authentication; OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication;
// Section 3.1.2.1 Authentication Request - http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // Section 3.1.2.1 Authentication Request - http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope value.
if (authorizationCodeAuthentication.getAuthorizationExchange() if (token.getAuthorizationExchange()
.getAuthorizationRequest().getScopes().contains("openid")) { .getAuthorizationRequest().getScopes().contains("openid")) {
// This is an OpenID Connect Authentication Request so return null // This is an OpenID Connect Authentication Request so return null
// and let OidcAuthorizationCodeReactiveAuthenticationManager handle it instead once one is created // and let OidcAuthorizationCodeReactiveAuthenticationManager handle it instead once one is created
@ -92,34 +87,28 @@ public class OAuth2LoginReactiveAuthenticationManager implements
// return Mono.empty(); // return Mono.empty();
} }
OAuth2AuthorizationExchangeValidator.validate(authorizationCodeAuthentication.getAuthorizationExchange()); return this.authorizationCodeManager.authenticate(token)
.cast(OAuth2AuthorizationCodeAuthenticationToken.class)
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( .flatMap(this::onSuccess);
authorizationCodeAuthentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange());
return this.accessTokenResponseClient.getTokenResponse(authzRequest)
.flatMap(accessTokenResponse -> authenticationResult(authorizationCodeAuthentication, accessTokenResponse));
}); });
} }
private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { private Mono<OAuth2LoginAuthenticationToken> onSuccess(OAuth2AuthorizationCodeAuthenticationToken authentication) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = authentication.getAccessToken();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters(); Map<String, Object> additionalParameters = authentication.getAdditionalParameters();
OAuth2UserRequest userRequest = new OAuth2UserRequest( OAuth2UserRequest userRequest = new OAuth2UserRequest(authentication.getClientRegistration(), accessToken, additionalParameters);
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
return this.userService.loadUser(userRequest) return this.userService.loadUser(userRequest)
.map(oauth2User -> { .map(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken(
authorizationCodeAuthentication.getClientRegistration(), authentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange(), authentication.getAuthorizationExchange(),
oauth2User, oauth2User,
mappedAuthorities, mappedAuthorities,
accessToken, accessToken,
accessTokenResponse.getRefreshToken()); authentication.getRefreshToken());
return authenticationResult; return authenticationResult;
}); });
} }

View File

@ -19,6 +19,7 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
@ -98,7 +99,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
@Override @Override
public Mono<Authentication> authenticate(Authentication authentication) { public Mono<Authentication> authenticate(Authentication authentication) {
return Mono.defer(() -> { return Mono.defer(() -> {
OAuth2LoginAuthenticationToken authorizationCodeAuthentication = (OAuth2LoginAuthenticationToken) authentication; OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = (OAuth2AuthorizationCodeAuthenticationToken) authentication;
// Section 3.1.2.1 Authentication Request - http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // Section 3.1.2.1 Authentication Request - http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope value.
@ -149,7 +150,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
this.decoderFactory = decoderFactory; this.decoderFactory = decoderFactory;
} }
private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters(); Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();

View File

@ -17,6 +17,7 @@
package org.springframework.security.oauth2.client.web.server; package org.springframework.security.oauth2.client.web.server;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@ -30,7 +31,6 @@ import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
/** /**
@ -40,7 +40,7 @@ import reactor.core.publisher.Mono;
* @since 5.1 * @since 5.1
* @see org.springframework.security.web.server.authentication.AuthenticationWebFilter#setServerAuthenticationConverter(ServerAuthenticationConverter) * @see org.springframework.security.web.server.authentication.AuthenticationWebFilter#setServerAuthenticationConverter(ServerAuthenticationConverter)
*/ */
public class ServerOAuth2LoginAuthenticationTokenConverter public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter
implements ServerAuthenticationConverter { implements ServerAuthenticationConverter {
static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
@ -52,7 +52,7 @@ public class ServerOAuth2LoginAuthenticationTokenConverter
private final ReactiveClientRegistrationRepository clientRegistrationRepository; private final ReactiveClientRegistrationRepository clientRegistrationRepository;
public ServerOAuth2LoginAuthenticationTokenConverter( public ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(
ReactiveClientRegistrationRepository clientRegistrationRepository) { ReactiveClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository; this.clientRegistrationRepository = clientRegistrationRepository;
@ -83,7 +83,7 @@ public class ServerOAuth2LoginAuthenticationTokenConverter
}); });
} }
private Mono<OAuth2LoginAuthenticationToken> authenticationRequest(ServerWebExchange exchange, OAuth2AuthorizationRequest authorizationRequest) { private Mono<OAuth2AuthorizationCodeAuthenticationToken> authenticationRequest(ServerWebExchange exchange, OAuth2AuthorizationRequest authorizationRequest) {
return Mono.just(authorizationRequest) return Mono.just(authorizationRequest)
.map(OAuth2AuthorizationRequest::getAdditionalParameters) .map(OAuth2AuthorizationRequest::getAdditionalParameters)
.flatMap(additionalParams -> { .flatMap(additionalParams -> {
@ -96,7 +96,7 @@ public class ServerOAuth2LoginAuthenticationTokenConverter
.switchIfEmpty(oauth2AuthenticationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE)) .switchIfEmpty(oauth2AuthenticationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE))
.map(clientRegistration -> { .map(clientRegistration -> {
OAuth2AuthorizationResponse authorizationResponse = convertResponse(exchange); OAuth2AuthorizationResponse authorizationResponse = convertResponse(exchange);
OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken(
clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse));
return authenticationRequest; return authenticationRequest;
}); });

View File

@ -178,7 +178,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
} }
private OAuth2LoginAuthenticationToken loginToken() { private OAuth2AuthorizationCodeAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build(); ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
.authorizationCode() .authorizationCode()
@ -193,6 +193,6 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
.build(); .build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse); authorizationResponse);
return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange); return new OAuth2AuthorizationCodeAuthenticationToken(clientRegistration, authorizationExchange);
} }
} }

View File

@ -24,7 +24,7 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
@ -182,7 +182,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setDecoderFactory(c -> this.jwtDecoder); this.manager.setDecoderFactory(c -> this.jwtDecoder);
OAuth2AuthenticationToken result = (OAuth2AuthenticationToken) this.manager.authenticate(loginToken()).block(); OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block();
assertThat(result.getPrincipal()).isEqualTo(user); assertThat(result.getPrincipal()).isEqualTo(user);
assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities()); assertThat(result.getAuthorities()).containsOnlyElementsOf(user.getAuthorities());
@ -192,6 +192,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
// gh-5368 // gh-5368
@Test @Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() { public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
ClientRegistration clientRegistration = this.registration.build();
Map<String, Object> additionalParameters = new HashMap<>(); Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()); additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue());
additionalParameters.put("param1", "value1"); additionalParameters.put("param1", "value1");
@ -204,7 +205,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
claims.put(IdTokenClaimNames.SUB, "rob"); claims.put(IdTokenClaimNames.SUB, "rob");
claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId")); claims.put(IdTokenClaimNames.AUD, Arrays.asList(clientRegistration.getClientId()));
Instant issuedAt = Instant.now(); Instant issuedAt = Instant.now();
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600); Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims); Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims);
@ -222,7 +223,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
} }
private OAuth2LoginAuthenticationToken loginToken() { private OAuth2AuthorizationCodeAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build(); ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
.authorizationCode() .authorizationCode()
@ -237,6 +238,6 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
.build(); .build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse); authorizationResponse);
return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange); return new OAuth2AuthorizationCodeAuthenticationToken(clientRegistration, authorizationExchange);
} }
} }

View File

@ -16,13 +16,6 @@
package org.springframework.security.oauth2.client.web.server; package org.springframework.security.oauth2.client.web.server;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import java.util.Collections;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -30,7 +23,7 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -39,15 +32,21 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.util.Collections;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
/** /**
* @author Rob Winch * @author Rob Winch
* @since 5.1 * @since 5.1
*/ */
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class ServerOAuth2LoginAuthenticationTokenConverterTest { public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest {
@Mock @Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository; private ReactiveClientRegistrationRepository clientRegistrationRepository;
@ -79,11 +78,11 @@ public class ServerOAuth2LoginAuthenticationTokenConverterTest {
private final MockServerHttpRequest.BaseBuilder<?> request = MockServerHttpRequest.get("/"); private final MockServerHttpRequest.BaseBuilder<?> request = MockServerHttpRequest.get("/");
private ServerOAuth2LoginAuthenticationTokenConverter converter; private ServerOAuth2AuthorizationCodeAuthenticationTokenConverter converter;
@Before @Before
public void setup() { public void setup() {
this.converter = new ServerOAuth2LoginAuthenticationTokenConverter(this.clientRegistrationRepository); this.converter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(this.clientRegistrationRepository);
this.converter.setAuthorizationRequestRepository(this.authorizationRequestRepository); this.converter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
} }
@ -102,8 +101,7 @@ public class ServerOAuth2LoginAuthenticationTokenConverterTest {
assertThatThrownBy(() -> applyConverter()) assertThatThrownBy(() -> applyConverter())
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.hasMessageContaining( .hasMessageContaining(ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
ServerOAuth2LoginAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
} }
@Test @Test
@ -113,8 +111,7 @@ public class ServerOAuth2LoginAuthenticationTokenConverterTest {
assertThatThrownBy(() -> applyConverter()) assertThatThrownBy(() -> applyConverter())
.isInstanceOf(OAuth2AuthenticationException.class) .isInstanceOf(OAuth2AuthenticationException.class)
.hasMessageContaining( .hasMessageContaining(ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
ServerOAuth2LoginAuthenticationTokenConverter.CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
} }
@Test @Test
@ -133,7 +130,7 @@ public class ServerOAuth2LoginAuthenticationTokenConverterTest {
when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build())); when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build()));
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration)); when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.clientRegistration));
OAuth2LoginAuthenticationToken result = applyConverter(); OAuth2AuthorizationCodeAuthenticationToken result = applyConverter();
OAuth2AuthorizationResponse exchange = result OAuth2AuthorizationResponse exchange = result
.getAuthorizationExchange().getAuthorizationResponse(); .getAuthorizationExchange().getAuthorizationResponse();
@ -141,8 +138,8 @@ public class ServerOAuth2LoginAuthenticationTokenConverterTest {
assertThat(exchange.getCode()).isEqualTo("code"); assertThat(exchange.getCode()).isEqualTo("code");
} }
private OAuth2LoginAuthenticationToken applyConverter() { private OAuth2AuthorizationCodeAuthenticationToken applyConverter() {
MockServerWebExchange exchange = MockServerWebExchange.from(this.request); MockServerWebExchange exchange = MockServerWebExchange.from(this.request);
return (OAuth2LoginAuthenticationToken) this.converter.convert(exchange).block(); return (OAuth2AuthorizationCodeAuthenticationToken) this.converter.convert(exchange).block();
} }
} }

View File

@ -112,7 +112,12 @@ class OAuth2AccessTokenResponseBodyExtractor
Map<String, Object> additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); Map<String, Object> additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters());
return OAuth2AccessTokenResponse.withToken(accessToken.getValue()).tokenType(accessTokenType).expiresIn(expiresIn).scopes(scopes) return OAuth2AccessTokenResponse.withToken(accessToken.getValue())
.refreshToken(refreshToken).additionalParameters(additionalParameters).build(); .tokenType(accessTokenType)
.expiresIn(expiresIn)
.scopes(scopes)
.refreshToken(refreshToken)
.additionalParameters(additionalParameters)
.build();
} }
} }