Move logic from AuthorizationCodeAuthenticationFilter to OAuth2UserAuthenticationProvider

This commit is contained in:
Joe Grandja 2017-10-11 17:23:59 -04:00
parent 18df9a869e
commit df474e04d8
3 changed files with 68 additions and 84 deletions

View File

@ -20,6 +20,9 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.user.OAuth2UserService;
import org.springframework.security.oauth2.core.user.OAuth2User;
@ -55,6 +58,7 @@ import java.util.Collection;
* @see OidcUser
*/
public class OAuth2UserAuthenticationProvider implements AuthenticationProvider {
private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
private final OAuth2UserService userService;
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
@ -65,11 +69,18 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication;
OAuth2UserAuthenticationToken userAuthentication = (OAuth2UserAuthenticationToken) authentication;
OAuth2ClientAuthenticationToken clientAuthentication = userAuthentication.getClientAuthentication();
OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication();
if (this.userAuthenticated() && this.userAuthenticatedSameProviderAs(clientAuthentication)) {
// Create a new user authentication (using same principal)
// but with a different client authentication association
return this.createUserAuthentication(
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(),
clientAuthentication);
}
OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication);
OAuth2User oauth2User = this.userService.loadUser(clientAuthentication);
Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
@ -77,12 +88,12 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
OAuth2UserAuthenticationToken authenticationResult;
if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) {
authenticationResult = new OidcUserAuthenticationToken(
(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication);
(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)clientAuthentication);
} else {
authenticationResult = new OAuth2UserAuthenticationToken(
oauth2User, mappedAuthorities, oauth2ClientAuthentication);
oauth2User, mappedAuthorities, clientAuthentication);
}
authenticationResult.setDetails(oauth2ClientAuthentication.getDetails());
authenticationResult.setDetails(clientAuthentication.getDetails());
return authenticationResult;
}
@ -96,4 +107,52 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
public boolean supports(Class<?> authentication) {
return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication);
}
private boolean userAuthenticated() {
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
return currentAuthentication != null &&
currentAuthentication instanceof OAuth2UserAuthenticationToken &&
currentAuthentication.isAuthenticated();
}
private boolean userAuthenticatedSameProviderAs(OAuth2ClientAuthenticationToken clientAuthentication) {
OAuth2UserAuthenticationToken currentUserAuthentication =
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
String userProviderId = this.providerIdentifierStrategy.getIdentifier(
currentUserAuthentication.getClientAuthentication().getClientRegistration());
String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
clientAuthentication.getClientRegistration());
return userProviderId.equals(clientProviderId);
}
private OAuth2UserAuthenticationToken createUserAuthentication(
OAuth2UserAuthenticationToken currentUserAuthentication,
OAuth2ClientAuthenticationToken newClientAuthentication) {
if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
return new OidcUserAuthenticationToken(
(OidcUser) currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
} else {
return new OAuth2UserAuthenticationToken(
(OAuth2User)currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
}
}
private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
@Override
public String getIdentifier(ClientRegistration clientRegistration) {
StringBuilder builder = new StringBuilder();
builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
return builder.toString();
}
}
}

View File

@ -18,23 +18,17 @@ package org.springframework.security.oauth2.client.web;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
import org.springframework.security.oauth2.oidc.core.user.OidcUser;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@ -81,7 +75,6 @@ import java.io.IOException;
public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
private AuthorizationResponseMatcher authorizationResponseMatcher;
private ClientRegistrationRepository clientRegistrationRepository;
private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
@ -135,20 +128,8 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
(OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication);
OAuth2UserAuthenticationToken oauth2UserAuthentication;
if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) {
// Create a new user authentication (using same principal)
// but with a different client authentication association
oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication);
} else {
// Authenticate the user... the user needs to be authenticated
// before we can associate the client authentication to the user
oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate(
this.createUserAuthentication(oauth2ClientAuthentication));
}
return oauth2UserAuthentication;
return this.getAuthenticationManager().authenticate(
new OAuth2UserAuthenticationToken(oauth2ClientAuthentication));
}
public final RequestMatcher getAuthorizationResponseMatcher() {
@ -171,50 +152,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
this.authorizationRequestRepository = authorizationRequestRepository;
}
private boolean authenticated() {
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
return currentAuthentication != null &&
currentAuthentication instanceof OAuth2UserAuthenticationToken &&
currentAuthentication.isAuthenticated();
}
private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) {
OAuth2UserAuthenticationToken userAuthentication =
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
String userProviderId = this.providerIdentifierStrategy.getIdentifier(
userAuthentication.getClientAuthentication().getClientRegistration());
String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
oauth2ClientAuthentication.getClientRegistration());
return userProviderId.equals(clientProviderId);
}
private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) {
if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) {
return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication);
} else {
return new OAuth2UserAuthenticationToken(clientAuthentication);
}
}
private OAuth2UserAuthenticationToken createUserAuthentication(
OAuth2UserAuthenticationToken currentUserAuthentication,
OAuth2ClientAuthenticationToken newClientAuthentication) {
if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
return new OidcUserAuthenticationToken(
(OidcUser) currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
} else {
return new OAuth2UserAuthenticationToken(
(OAuth2User)currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
}
}
private static class AuthorizationResponseMatcher implements RequestMatcher {
private final String baseUri;
@ -266,16 +203,4 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
}
}
}
private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
@Override
public String getIdentifier(ClientRegistration clientRegistration) {
StringBuilder builder = new StringBuilder();
builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
return builder.toString();
}
}
}

View File

@ -128,7 +128,7 @@ public class AuthorizationCodeAuthenticationFilterTests {
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
authenticationArgCaptor.capture());
Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(clientAuthentication);
}
@Test