Extract authentication logic from AuthorizationCodeAuthenticationFilter

Fixes gh-4590
This commit is contained in:
Joe Grandja 2017-10-07 21:26:26 -04:00
parent d13c3a040c
commit 97c938e7f3
8 changed files with 77 additions and 133 deletions

View File

@ -21,6 +21,9 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.token.InMemoryAccessTokenRepository;
import org.springframework.security.oauth2.client.token.SecurityTokenRepository;
import org.springframework.security.oauth2.core.AccessToken;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
import org.springframework.util.Assert;
@ -32,7 +35,7 @@ import org.springframework.util.Assert;
*
* <p>
* The {@link AuthorizationCodeAuthenticationProvider} uses an {@link AuthorizationGrantAuthenticator}
* to authenticate the {@link AuthorizationCodeAuthenticationToken#getAuthorizationCode()} and ultimately
* to authenticate the <i>authorization code</i> credential and ultimately
* return an <i>&quot;Authorized Client&quot;</i> as an {@link OAuth2ClientAuthenticationToken}.
*
* @author Joe Grandja
@ -49,6 +52,8 @@ import org.springframework.util.Assert;
* @see <a target="_blank" href="http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse">Section 3.1.3.3 OpenID Connect Token Response</a>
*/
public class AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter";
private final AuthorizationGrantAuthenticator<AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticator;
private SecurityTokenRepository<AccessToken> accessTokenRepository = new InMemoryAccessTokenRepository();
@ -64,6 +69,24 @@ public class AuthorizationCodeAuthenticationProvider implements AuthenticationPr
AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
(AuthorizationCodeAuthenticationToken) authentication;
AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationRequest();
AuthorizationResponse authorizationResponse = authorizationCodeAuthentication.getAuthorizationResponse();
if (authorizationResponse.statusError()) {
throw new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString());
}
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
if (!authorizationResponse.getRedirectUri().equals(authorizationRequest.getRedirectUri())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
this.authorizationCodeAuthenticator.authenticate(authorizationCodeAuthentication);

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
import org.springframework.util.Assert;
/**
@ -28,38 +29,37 @@ import org.springframework.util.Assert;
* @since 5.0
* @see AuthorizationGrantAuthenticationToken
* @see ClientRegistration
* @see AuthorizationRequest
* @see AuthorizationResponse
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3.1">Section 1.3.1 Authorization Code Grant</a>
*/
public class AuthorizationCodeAuthenticationToken extends AuthorizationGrantAuthenticationToken {
private final String authorizationCode;
private final ClientRegistration clientRegistration;
private final AuthorizationRequest authorizationRequest;
private final AuthorizationResponse authorizationResponse;
public AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration,
AuthorizationRequest authorizationRequest,
AuthorizationResponse authorizationResponse) {
public AuthorizationCodeAuthenticationToken(String authorizationCode,
ClientRegistration clientRegistration,
AuthorizationRequest authorizationRequest) {
super(AuthorizationGrantType.AUTHORIZATION_CODE);
Assert.hasText(authorizationCode, "authorizationCode cannot be empty");
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
this.authorizationCode = authorizationCode;
Assert.notNull(authorizationResponse, "authorizationResponse cannot be null");
this.clientRegistration = clientRegistration;
this.authorizationRequest = authorizationRequest;
this.authorizationResponse = authorizationResponse;
this.setAuthenticated(false);
}
@Override
public Object getPrincipal() {
return this.getClientRegistration().getClientId();
return "";
}
@Override
public Object getCredentials() {
return this.getAuthorizationCode();
}
public String getAuthorizationCode() {
return this.authorizationCode;
return "";
}
public ClientRegistration getClientRegistration() {
@ -69,4 +69,8 @@ public class AuthorizationCodeAuthenticationToken extends AuthorizationGrantAuth
public AuthorizationRequest getAuthorizationRequest() {
return this.authorizationRequest;
}
public AuthorizationResponse getAuthorizationResponse() {
return this.authorizationResponse;
}
}

View File

@ -82,8 +82,6 @@ import java.io.IOException;
public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter";
private final AuthorizationResponseConverter authorizationResponseConverter = new AuthorizationResponseConverter();
private ClientRegistrationRepository clientRegistrationRepository;
private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher();
@ -98,16 +96,16 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
throws AuthenticationException, IOException, ServletException {
AuthorizationRequest authorizationRequest = this.getAuthorizationRequestRepository().loadAuthorizationRequest(request);
if (authorizationRequest == null) {
OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
AuthorizationResponse authorizationResponse = this.authorizationResponseConverter.apply(request);
if (authorizationResponse.statusError()) {
this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
throw new OAuth2AuthenticationException(
authorizationResponse.getError(), authorizationResponse.getError().toString());
}
AuthorizationRequest matchingAuthorizationRequest = this.resolveAuthorizationRequest(request);
String registrationId = (String)matchingAuthorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
String registrationId = (String)authorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
ClientRegistration clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
// The clientRegistration.redirectUri may contain Uri template variables, whether it's configured by
@ -116,13 +114,13 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
// The resulting redirectUri used for the authorization request and saved within the AuthorizationRequestRepository
// MUST BE the same one used to complete the authorization code flow.
// Therefore, we'll create a copy of the clientRegistration and override the redirectUri
// with the one contained in matchingAuthorizationRequest.
// with the one contained in authorizationRequest.
clientRegistration = new ClientRegistration.Builder(clientRegistration)
.redirectUri(matchingAuthorizationRequest.getRedirectUri())
.redirectUri(authorizationRequest.getRedirectUri())
.build();
AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = new AuthorizationCodeAuthenticationToken(
authorizationResponse.getCode(), clientRegistration, matchingAuthorizationRequest);
clientRegistration, authorizationRequest, authorizationResponse);
authorizationCodeAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
@ -172,31 +170,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
this.authorizationRequestRepository = authorizationRequestRepository;
}
private AuthorizationRequest resolveAuthorizationRequest(HttpServletRequest request) {
AuthorizationRequest authorizationRequest =
this.getAuthorizationRequestRepository().loadAuthorizationRequest(request);
if (authorizationRequest == null) {
OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
this.assertMatchingAuthorizationRequest(request, authorizationRequest);
return authorizationRequest;
}
private void assertMatchingAuthorizationRequest(HttpServletRequest request, AuthorizationRequest authorizationRequest) {
String state = request.getParameter(OAuth2Parameter.STATE);
if (!authorizationRequest.getState().equals(state)) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
if (!request.getRequestURL().toString().equals(authorizationRequest.getRedirectUri())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private boolean authenticated() {
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
return currentAuthentication != null &&

View File

@ -38,15 +38,18 @@ public final class AuthorizationResponseConverter implements Function<HttpServle
String code = request.getParameter(OAuth2Parameter.CODE);
String errorCode = request.getParameter(OAuth2Parameter.ERROR);
String state = request.getParameter(OAuth2Parameter.STATE);
String redirectUri = request.getRequestURL().toString();
if (StringUtils.hasText(code)) {
return AuthorizationResponse.success(code)
.redirectUri(redirectUri)
.state(state)
.build();
} else if (StringUtils.hasText(errorCode)) {
String description = request.getParameter(OAuth2Parameter.ERROR_DESCRIPTION);
String uri = request.getParameter(OAuth2Parameter.ERROR_URI);
return AuthorizationResponse.error(errorCode)
.redirectUri(redirectUri)
.errorDescription(description)
.errorUri(uri)
.state(state)

View File

@ -76,7 +76,8 @@ public class NimbusAuthorizationCodeTokenExchanger implements AuthorizationGrant
ClientRegistration clientRegistration = authorizationCodeAuthenticationToken.getClientRegistration();
// Build the authorization code grant request for the token endpoint
AuthorizationCode authorizationCode = new AuthorizationCode(authorizationCodeAuthenticationToken.getAuthorizationCode());
AuthorizationCode authorizationCode = new AuthorizationCode(
authorizationCodeAuthenticationToken.getAuthorizationResponse().getCode());
URI redirectUri = this.toURI(clientRegistration.getRedirectUri());
AuthorizationGrant authorizationCodeGrant = new AuthorizationCodeGrant(authorizationCode, redirectUri);
URI tokenUri = this.toURI(clientRegistration.getProviderDetails().getTokenUri());

View File

@ -152,55 +152,6 @@ public class AuthorizationCodeAuthenticationFilterTests {
verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "authorization_request_not_found");
}
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseWithInvalidStateParamThenThrowOAuth2AuthenticationExceptionInvalidStateParameter() throws Exception {
ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
AuthorizationCodeAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
filter.setAuthenticationFailureHandler(failureHandler);
AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
filter.setAuthorizationRequestRepository(authorizationRequestRepository);
MockHttpServletRequest request = this.setupRequest(clientRegistration);
String authCode = "some code";
String state = "some other state";
request.addParameter(OAuth2Parameter.CODE, authCode);
request.addParameter(OAuth2Parameter.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, "some state");
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "invalid_state_parameter");
}
@Test
public void doFilterWhenAuthorizationCodeSuccessResponseWithInvalidRedirectUriParamThenThrowOAuth2AuthenticationExceptionInvalidRedirectUriParameter() throws Exception {
ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
AuthorizationCodeAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
filter.setAuthenticationFailureHandler(failureHandler);
AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
filter.setAuthorizationRequestRepository(authorizationRequestRepository);
MockHttpServletRequest request = this.setupRequest(clientRegistration);
request.setRequestURI(request.getRequestURI() + "-other");
String authCode = "some code";
String state = "some state";
request.addParameter(OAuth2Parameter.CODE, authCode);
request.addParameter(OAuth2Parameter.STATE, state);
MockHttpServletResponse response = new MockHttpServletResponse();
setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain);
verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "invalid_redirect_uri_parameter");
}
private void verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(AuthorizationCodeAuthenticationFilter filter,
AuthenticationFailureHandler failureHandler,
String errorCode) throws Exception {

View File

@ -27,21 +27,26 @@ import org.springframework.util.StringUtils;
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
*/
public final class AuthorizationResponse {
private String code;
private String redirectUri;
private String state;
private String code;
private OAuth2Error error;
private AuthorizationResponse() {
}
public String getCode() {
return this.code;
public String getRedirectUri() {
return this.redirectUri;
}
public String getState() {
return this.state;
}
public String getCode() {
return this.code;
}
public OAuth2Error getError() {
return this.error;
}
@ -65,8 +70,9 @@ public final class AuthorizationResponse {
}
public static class Builder {
private String code;
private String redirectUri;
private String state;
private String code;
private String errorCode;
private String errorDescription;
private String errorUri;
@ -74,8 +80,8 @@ public final class AuthorizationResponse {
private Builder() {
}
public Builder code(String code) {
this.code = code;
public Builder redirectUri(String redirectUri) {
this.redirectUri = redirectUri;
return this;
}
@ -84,6 +90,11 @@ public final class AuthorizationResponse {
return this;
}
public Builder code(String code) {
this.code = code;
return this;
}
public Builder errorCode(String errorCode) {
this.errorCode = errorCode;
return this;
@ -103,14 +114,17 @@ public final class AuthorizationResponse {
if (StringUtils.hasText(this.code) && StringUtils.hasText(this.errorCode)) {
throw new IllegalArgumentException("code and errorCode cannot both be set");
}
Assert.hasText(this.redirectUri, "redirectUri cannot be empty");
AuthorizationResponse authorizationResponse = new AuthorizationResponse();
authorizationResponse.redirectUri = this.redirectUri;
authorizationResponse.state = this.state;
if (StringUtils.hasText(this.code)) {
authorizationResponse.code = this.code;
} else {
authorizationResponse.error = new OAuth2Error(
this.errorCode, this.errorDescription, this.errorUri);
}
authorizationResponse.state = this.state;
return authorizationResponse;
}
}

View File

@ -44,7 +44,6 @@ import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthentic
import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.AuthorizationGrantTokenExchanger;
import org.springframework.security.oauth2.core.AccessToken;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
import org.springframework.security.oauth2.core.endpoint.ResponseType;
import org.springframework.security.oauth2.core.endpoint.TokenResponse;
@ -282,30 +281,6 @@ public class OAuth2LoginApplicationTests {
assertThat(errorElement.asText()).contains("invalid_redirect_uri_parameter");
}
@Test
public void requestAuthorizationCodeGrantWhenStandardErrorCodeResponseThenDisplayLoginPageWithError() throws Exception {
HtmlPage page = this.webClient.getPage("/");
URL loginPageUrl = page.getBaseURL();
URL loginErrorPageUrl = new URL(loginPageUrl.toString() + "?error");
String error = OAuth2Error.INVALID_CLIENT_ERROR_CODE;
String state = "state";
String redirectUri = AUTHORIZE_BASE_URL + "/" + this.githubClientRegistration.getRegistrationId();
String authorizationResponseUri =
UriComponentsBuilder.fromHttpUrl(redirectUri)
.queryParam(OAuth2Parameter.ERROR, error)
.queryParam(OAuth2Parameter.STATE, state)
.build().encode().toUriString();
page = this.webClient.getPage(new URL(authorizationResponseUri));
assertThat(page.getBaseURL()).isEqualTo(loginErrorPageUrl);
HtmlElement errorElement = page.getBody().getFirstByXPath("p");
assertThat(errorElement).isNotNull();
assertThat(errorElement.asText()).contains(error);
}
private void assertLoginPage(HtmlPage page) throws Exception {
assertThat(page.getTitleText()).isEqualTo("Login Page");