Move logic from AuthorizationCodeAuthenticationFilter to OAuth2UserAuthenticationProvider
This commit is contained in:
parent
18df9a869e
commit
df474e04d8
|
@ -20,6 +20,9 @@ import org.springframework.security.core.Authentication;
|
|||
import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
|
||||
import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService;
|
||||
import org.springframework.security.oauth2.client.user.OAuth2UserService;
|
||||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||
|
@ -55,6 +58,7 @@ import java.util.Collection;
|
|||
* @see OidcUser
|
||||
*/
|
||||
public class OAuth2UserAuthenticationProvider implements AuthenticationProvider {
|
||||
private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
|
||||
private final OAuth2UserService userService;
|
||||
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
|
||||
|
||||
|
@ -65,11 +69,18 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
|
|||
|
||||
@Override
|
||||
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
|
||||
OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication;
|
||||
OAuth2UserAuthenticationToken userAuthentication = (OAuth2UserAuthenticationToken) authentication;
|
||||
OAuth2ClientAuthenticationToken clientAuthentication = userAuthentication.getClientAuthentication();
|
||||
|
||||
OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication();
|
||||
if (this.userAuthenticated() && this.userAuthenticatedSameProviderAs(clientAuthentication)) {
|
||||
// Create a new user authentication (using same principal)
|
||||
// but with a different client authentication association
|
||||
return this.createUserAuthentication(
|
||||
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(),
|
||||
clientAuthentication);
|
||||
}
|
||||
|
||||
OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication);
|
||||
OAuth2User oauth2User = this.userService.loadUser(clientAuthentication);
|
||||
|
||||
Collection<? extends GrantedAuthority> mappedAuthorities =
|
||||
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
|
||||
|
@ -77,12 +88,12 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
|
|||
OAuth2UserAuthenticationToken authenticationResult;
|
||||
if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) {
|
||||
authenticationResult = new OidcUserAuthenticationToken(
|
||||
(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication);
|
||||
(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)clientAuthentication);
|
||||
} else {
|
||||
authenticationResult = new OAuth2UserAuthenticationToken(
|
||||
oauth2User, mappedAuthorities, oauth2ClientAuthentication);
|
||||
oauth2User, mappedAuthorities, clientAuthentication);
|
||||
}
|
||||
authenticationResult.setDetails(oauth2ClientAuthentication.getDetails());
|
||||
authenticationResult.setDetails(clientAuthentication.getDetails());
|
||||
|
||||
return authenticationResult;
|
||||
}
|
||||
|
@ -96,4 +107,52 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
|
|||
public boolean supports(Class<?> authentication) {
|
||||
return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication);
|
||||
}
|
||||
|
||||
private boolean userAuthenticated() {
|
||||
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
return currentAuthentication != null &&
|
||||
currentAuthentication instanceof OAuth2UserAuthenticationToken &&
|
||||
currentAuthentication.isAuthenticated();
|
||||
}
|
||||
|
||||
private boolean userAuthenticatedSameProviderAs(OAuth2ClientAuthenticationToken clientAuthentication) {
|
||||
OAuth2UserAuthenticationToken currentUserAuthentication =
|
||||
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
|
||||
|
||||
String userProviderId = this.providerIdentifierStrategy.getIdentifier(
|
||||
currentUserAuthentication.getClientAuthentication().getClientRegistration());
|
||||
String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
|
||||
clientAuthentication.getClientRegistration());
|
||||
|
||||
return userProviderId.equals(clientProviderId);
|
||||
}
|
||||
|
||||
private OAuth2UserAuthenticationToken createUserAuthentication(
|
||||
OAuth2UserAuthenticationToken currentUserAuthentication,
|
||||
OAuth2ClientAuthenticationToken newClientAuthentication) {
|
||||
|
||||
if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
|
||||
return new OidcUserAuthenticationToken(
|
||||
(OidcUser) currentUserAuthentication.getPrincipal(),
|
||||
currentUserAuthentication.getAuthorities(),
|
||||
newClientAuthentication);
|
||||
} else {
|
||||
return new OAuth2UserAuthenticationToken(
|
||||
(OAuth2User)currentUserAuthentication.getPrincipal(),
|
||||
currentUserAuthentication.getAuthorities(),
|
||||
newClientAuthentication);
|
||||
}
|
||||
}
|
||||
|
||||
private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
|
||||
|
||||
@Override
|
||||
public String getIdentifier(ClientRegistration clientRegistration) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
|
||||
builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
|
||||
builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
|
||||
return builder.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,23 +18,17 @@ package org.springframework.security.oauth2.client.web;
|
|||
import org.springframework.security.authentication.AuthenticationManager;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider;
|
||||
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
|
||||
import org.springframework.security.oauth2.core.user.OAuth2User;
|
||||
import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
|
||||
import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
|
||||
import org.springframework.security.oauth2.oidc.core.user.OidcUser;
|
||||
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
|
@ -81,7 +75,6 @@ import java.io.IOException;
|
|||
public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
|
||||
public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
|
||||
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
|
||||
private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
|
||||
private AuthorizationResponseMatcher authorizationResponseMatcher;
|
||||
private ClientRegistrationRepository clientRegistrationRepository;
|
||||
private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
|
||||
|
@ -135,20 +128,8 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|||
OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
|
||||
(OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication);
|
||||
|
||||
OAuth2UserAuthenticationToken oauth2UserAuthentication;
|
||||
if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) {
|
||||
// Create a new user authentication (using same principal)
|
||||
// but with a different client authentication association
|
||||
oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
|
||||
oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication);
|
||||
} else {
|
||||
// Authenticate the user... the user needs to be authenticated
|
||||
// before we can associate the client authentication to the user
|
||||
oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate(
|
||||
this.createUserAuthentication(oauth2ClientAuthentication));
|
||||
}
|
||||
|
||||
return oauth2UserAuthentication;
|
||||
return this.getAuthenticationManager().authenticate(
|
||||
new OAuth2UserAuthenticationToken(oauth2ClientAuthentication));
|
||||
}
|
||||
|
||||
public final RequestMatcher getAuthorizationResponseMatcher() {
|
||||
|
@ -171,50 +152,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|||
this.authorizationRequestRepository = authorizationRequestRepository;
|
||||
}
|
||||
|
||||
private boolean authenticated() {
|
||||
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
return currentAuthentication != null &&
|
||||
currentAuthentication instanceof OAuth2UserAuthenticationToken &&
|
||||
currentAuthentication.isAuthenticated();
|
||||
}
|
||||
|
||||
private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) {
|
||||
OAuth2UserAuthenticationToken userAuthentication =
|
||||
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
|
||||
|
||||
String userProviderId = this.providerIdentifierStrategy.getIdentifier(
|
||||
userAuthentication.getClientAuthentication().getClientRegistration());
|
||||
String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
|
||||
oauth2ClientAuthentication.getClientRegistration());
|
||||
|
||||
return userProviderId.equals(clientProviderId);
|
||||
}
|
||||
|
||||
private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) {
|
||||
if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) {
|
||||
return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication);
|
||||
} else {
|
||||
return new OAuth2UserAuthenticationToken(clientAuthentication);
|
||||
}
|
||||
}
|
||||
|
||||
private OAuth2UserAuthenticationToken createUserAuthentication(
|
||||
OAuth2UserAuthenticationToken currentUserAuthentication,
|
||||
OAuth2ClientAuthenticationToken newClientAuthentication) {
|
||||
|
||||
if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
|
||||
return new OidcUserAuthenticationToken(
|
||||
(OidcUser) currentUserAuthentication.getPrincipal(),
|
||||
currentUserAuthentication.getAuthorities(),
|
||||
newClientAuthentication);
|
||||
} else {
|
||||
return new OAuth2UserAuthenticationToken(
|
||||
(OAuth2User)currentUserAuthentication.getPrincipal(),
|
||||
currentUserAuthentication.getAuthorities(),
|
||||
newClientAuthentication);
|
||||
}
|
||||
}
|
||||
|
||||
private static class AuthorizationResponseMatcher implements RequestMatcher {
|
||||
private final String baseUri;
|
||||
|
||||
|
@ -266,16 +203,4 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
|
||||
|
||||
@Override
|
||||
public String getIdentifier(ClientRegistration clientRegistration) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
|
||||
builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
|
||||
builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
|
||||
return builder.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -128,7 +128,7 @@ public class AuthorizationCodeAuthenticationFilterTests {
|
|||
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
||||
Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
|
||||
authenticationArgCaptor.capture());
|
||||
Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
|
||||
Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(clientAuthentication);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue