Add converter for authentication result in OAuth2LoginAuthenticationFilter
Closes gh-10033
This commit is contained in:
parent
fc553bf19a
commit
6d6dc113d8
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2021 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.web;
|
|||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.security.authentication.AuthenticationManager;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.AuthenticationException;
|
||||
|
@ -111,6 +112,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
|
|||
|
||||
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
|
||||
|
||||
private Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter = this::createAuthenticationResult;
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided
|
||||
* parameters.
|
||||
|
@ -190,9 +193,9 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
|
|||
authenticationRequest.setDetails(authenticationDetails);
|
||||
OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this
|
||||
.getAuthenticationManager().authenticate(authenticationRequest);
|
||||
OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken(
|
||||
authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
|
||||
authenticationResult.getClientRegistration().getRegistrationId());
|
||||
OAuth2AuthenticationToken oauth2Authentication = this.authenticationResultConverter
|
||||
.convert(authenticationResult);
|
||||
Assert.notNull(oauth2Authentication, "authentication result cannot be null");
|
||||
oauth2Authentication.setDetails(authenticationDetails);
|
||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
||||
authenticationResult.getClientRegistration(), oauth2Authentication.getName(),
|
||||
|
@ -213,4 +216,22 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
|
|||
this.authorizationRequestRepository = authorizationRequestRepository;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the converter responsible for converting from
|
||||
* {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken}
|
||||
* authentication result.
|
||||
* @param authenticationResultConverter the converter for
|
||||
* {@link OAuth2AuthenticationToken}'s
|
||||
*/
|
||||
public final void setAuthenticationResultConverter(
|
||||
Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter) {
|
||||
Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null");
|
||||
this.authenticationResultConverter = authenticationResultConverter;
|
||||
}
|
||||
|
||||
private OAuth2AuthenticationToken createAuthenticationResult(OAuth2LoginAuthenticationToken authenticationResult) {
|
||||
return new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
|
||||
authenticationResult.getClientRegistration().getRegistrationId());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
* Copyright 2002-2021 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.springframework.security.oauth2.client.web;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -33,10 +34,12 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
|
|||
import org.springframework.security.authentication.AuthenticationManager;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
import org.springframework.security.core.authority.AuthorityUtils;
|
||||
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||
|
@ -152,6 +155,12 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
|
||||
}
|
||||
|
||||
// gh-10033
|
||||
@Test
|
||||
public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() {
|
||||
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
|
||||
String requestUri = "/path";
|
||||
|
@ -416,6 +425,41 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||
assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
|
||||
}
|
||||
|
||||
// gh-10033
|
||||
@Test
|
||||
public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception {
|
||||
this.filter.setAuthenticationResultConverter((authentication) -> null);
|
||||
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
||||
String state = "state";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, state);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1, state);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response));
|
||||
}
|
||||
|
||||
// gh-10033
|
||||
@Test
|
||||
public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() {
|
||||
this.filter.setAuthenticationResultConverter(
|
||||
(authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(),
|
||||
authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
|
||||
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
||||
String state = "state";
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
||||
request.addParameter(OAuth2ParameterNames.STATE, state);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
this.setUpAuthorizationRequest(request, response, this.registration1, state);
|
||||
this.setUpAuthenticationResult(this.registration1);
|
||||
Authentication authenticationResult = this.filter.attemptAuthentication(request, response);
|
||||
assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class);
|
||||
}
|
||||
|
||||
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
||||
ClientRegistration registration, String state) {
|
||||
Map<String, Object> attributes = new HashMap<>();
|
||||
|
@ -454,4 +498,13 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|||
given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
|
||||
}
|
||||
|
||||
private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
|
||||
|
||||
CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
|
||||
String authorizedClientRegistrationId) {
|
||||
super(principal, authorities, authorizedClientRegistrationId);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue