Add converter for authentication result in OAuth2LoginAuthenticationFilter

Closes gh-10033
This commit is contained in:
Steve Riesenberg 2021-07-01 12:55:12 -05:00 committed by Steve Riesenberg
parent fc553bf19a
commit 6d6dc113d8
2 changed files with 79 additions and 5 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.web;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
@ -111,6 +112,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
private Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter = this::createAuthenticationResult;
/**
* Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided
* parameters.
@ -190,9 +193,9 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
authenticationRequest.setDetails(authenticationDetails);
OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this
.getAuthenticationManager().authenticate(authenticationRequest);
OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken(
authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
OAuth2AuthenticationToken oauth2Authentication = this.authenticationResultConverter
.convert(authenticationResult);
Assert.notNull(oauth2Authentication, "authentication result cannot be null");
oauth2Authentication.setDetails(authenticationDetails);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(), oauth2Authentication.getName(),
@ -213,4 +216,22 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
this.authorizationRequestRepository = authorizationRequestRepository;
}
/**
* Sets the converter responsible for converting from
* {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken}
* authentication result.
* @param authenticationResultConverter the converter for
* {@link OAuth2AuthenticationToken}'s
*/
public final void setAuthenticationResultConverter(
Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter) {
Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null");
this.authenticationResultConverter = authenticationResultConverter;
}
private OAuth2AuthenticationToken createAuthenticationResult(OAuth2LoginAuthenticationToken authenticationResult) {
return new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.web;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
@ -33,10 +34,12 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@ -152,6 +155,12 @@ public class OAuth2LoginAuthenticationFilterTests {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
}
// gh-10033
@Test
public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null));
}
@Test
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
String requestUri = "/path";
@ -416,6 +425,41 @@ public class OAuth2LoginAuthenticationFilterTests {
assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
}
// gh-10033
@Test
public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception {
this.filter.setAuthenticationResultConverter((authentication) -> null);
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(request, response, this.registration1, state);
this.setUpAuthenticationResult(this.registration1);
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response));
}
// gh-10033
@Test
public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() {
this.filter.setAuthenticationResultConverter(
(authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(),
authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
String state = "state";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
this.setUpAuthorizationRequest(request, response, this.registration1, state);
this.setUpAuthenticationResult(this.registration1);
Authentication authenticationResult = this.filter.attemptAuthentication(request, response);
assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class);
}
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration, String state) {
Map<String, Object> attributes = new HashMap<>();
@ -454,4 +498,13 @@ public class OAuth2LoginAuthenticationFilterTests {
given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
}
private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
String authorizedClientRegistrationId) {
super(principal, authorities, authorizedClientRegistrationId);
}
}
}