Use param matching for Authorization Response

Fixes gh-4576
This commit is contained in:
Joe Grandja 2017-09-26 15:24:08 -04:00
parent d191bcc8ac
commit 9a8ddebc94
7 changed files with 58 additions and 27 deletions

View File

@ -36,7 +36,6 @@ import org.springframework.security.oauth2.core.AccessToken;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.oidc.client.user.OidcUserService;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
import org.springframework.util.Assert;
import java.net.URI;
@ -48,7 +47,7 @@ import java.util.Map;
/**
* @author Joe Grandja
*/
final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecurityBuilder<H>, R extends RequestMatcher & RequestVariablesExtractor> extends
final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecurityBuilder<H>, R extends RequestMatcher> extends
AbstractAuthenticationFilterConfigurer<H, AuthorizationCodeAuthenticationFilterConfigurer<H, R>, AuthorizationCodeAuthenticationProcessingFilter> {
private R authorizationResponseMatcher;

View File

@ -166,7 +166,7 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
private RedirectionEndpointConfig() {
}
public <R extends RequestMatcher & RequestVariablesExtractor> RedirectionEndpointConfig requestMatcher(R authorizationResponseMatcher) {
public <R extends RequestMatcher> RedirectionEndpointConfig requestMatcher(R authorizationResponseMatcher) {
Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null");
OAuth2LoginConfigurer.this.authorizationCodeAuthenticationFilterConfigurer.authorizationResponseMatcher(authorizationResponseMatcher);
return this;

View File

@ -37,9 +37,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
import org.springframework.security.oauth2.core.endpoint.TokenResponseAttributes;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -111,20 +109,18 @@ import java.io.IOException;
*/
public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAuthenticationProcessingFilter {
public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
public static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
public static final String DEFAULT_AUTHORIZATION_RESPONSE_URI = DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}";
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter";
private final ErrorResponseAttributesConverter errorResponseConverter = new ErrorResponseAttributesConverter();
private final AuthorizationCodeAuthorizationResponseAttributesConverter authorizationCodeResponseConverter =
new AuthorizationCodeAuthorizationResponseAttributesConverter();
private RequestMatcher authorizationResponseMatcher = new AntPathRequestMatcher(DEFAULT_AUTHORIZATION_RESPONSE_URI);
private ClientRegistrationRepository clientRegistrationRepository;
private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher();
private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
public AuthorizationCodeAuthenticationProcessingFilter() {
super(DEFAULT_AUTHORIZATION_RESPONSE_URI);
super(new AuthorizationResponseMatcher());
}
@Override
@ -140,17 +136,8 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
}
AuthorizationRequestAttributes matchingAuthorizationRequest = this.resolveAuthorizationRequest(request);
String registrationId = ((RequestVariablesExtractor)this.getAuthorizationResponseMatcher())
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
ClientRegistration clientRegistration = null;
if (!StringUtils.isEmpty(registrationId)) {
clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
}
if (clientRegistration == null || !matchingAuthorizationRequest.getClientId().equals(clientRegistration.getClientId())) {
OAuth2Error oauth2Error = new OAuth2Error(OAuth2Error.INVALID_REQUEST_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String registrationId = (String)matchingAuthorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
ClientRegistration clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
// The clientRegistration.redirectUri may contain Uri template variables, whether it's configured by
// the user or configured by default. In these cases, the redirectUri will be expanded and ultimately changed
@ -180,7 +167,7 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
return this.authorizationResponseMatcher;
}
public final <T extends RequestMatcher & RequestVariablesExtractor> void setAuthorizationResponseMatcher(T authorizationResponseMatcher) {
public final <T extends RequestMatcher> void setAuthorizationResponseMatcher(T authorizationResponseMatcher) {
Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null");
this.authorizationResponseMatcher = authorizationResponseMatcher;
this.setRequiresAuthenticationRequestMatcher(authorizationResponseMatcher);
@ -228,4 +215,22 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private static class AuthorizationResponseMatcher implements RequestMatcher {
@Override
public boolean matches(HttpServletRequest request) {
return this.successResponse(request) || this.errorResponse(request);
}
private boolean successResponse(HttpServletRequest request) {
return StringUtils.hasText(request.getParameter(OAuth2Parameter.CODE)) &&
StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE));
}
private boolean errorResponse(HttpServletRequest request) {
return StringUtils.hasText(request.getParameter(OAuth2Parameter.ERROR)) &&
StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE));
}
}
}

View File

@ -19,6 +19,7 @@ import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes;
import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@ -122,6 +123,9 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter
String redirectUriStr = this.expandRedirectUri(request, clientRegistration);
Map<String,Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2Parameter.REGISTRATION_ID, clientRegistration.getRegistrationId());
AuthorizationRequestAttributes authorizationRequestAttributes =
AuthorizationRequestAttributes.withAuthorizationCode()
.clientId(clientRegistration.getClientId())
@ -129,6 +133,7 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter
.redirectUri(redirectUriStr)
.scope(clientRegistration.getScope())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response);

View File

@ -38,9 +38,8 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import java.util.HashMap;
import java.util.Map;
/**
* Tests {@link AuthorizationCodeAuthenticationProcessingFilter}.
@ -233,6 +232,9 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
ClientRegistration clientRegistration,
String state) {
Map<String,Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2Parameter.REGISTRATION_ID, clientRegistration.getRegistrationId());
AuthorizationRequestAttributes authorizationRequestAttributes =
AuthorizationRequestAttributes.withAuthorizationCode()
.clientId(clientRegistration.getClientId())
@ -240,6 +242,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
.redirectUri(clientRegistration.getRedirectUri())
.scope(clientRegistration.getScope())
.state(state)
.additionalParameters(additionalParameters)
.build();
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response);

View File

@ -21,7 +21,9 @@ import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
/**
@ -43,6 +45,7 @@ public final class AuthorizationRequestAttributes implements Serializable {
private String redirectUri;
private Set<String> scope;
private String state;
private Map<String,Object> additionalParameters;
private AuthorizationRequestAttributes() {
}
@ -75,6 +78,10 @@ public final class AuthorizationRequestAttributes implements Serializable {
return this.state;
}
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
public static Builder withAuthorizationCode() {
return new Builder(AuthorizationGrantType.AUTHORIZATION_CODE);
}
@ -107,8 +114,7 @@ public final class AuthorizationRequestAttributes implements Serializable {
}
public Builder scope(Set<String> scope) {
this.authorizationRequest.scope = Collections.unmodifiableSet(
CollectionUtils.isEmpty(scope) ? Collections.emptySet() : new LinkedHashSet<>(scope));
this.authorizationRequest.scope = scope;
return this;
}
@ -117,9 +123,20 @@ public final class AuthorizationRequestAttributes implements Serializable {
return this;
}
public Builder additionalParameters(Map<String,Object> additionalParameters) {
this.authorizationRequest.additionalParameters = additionalParameters;
return this;
}
public AuthorizationRequestAttributes build() {
Assert.hasText(this.authorizationRequest.clientId, "clientId cannot be empty");
Assert.hasText(this.authorizationRequest.authorizeUri, "authorizeUri cannot be empty");
this.authorizationRequest.scope = Collections.unmodifiableSet(
CollectionUtils.isEmpty(this.authorizationRequest.scope) ?
Collections.emptySet() : new LinkedHashSet<>(this.authorizationRequest.scope));
this.authorizationRequest.additionalParameters = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.authorizationRequest.additionalParameters) ?
Collections.emptyMap() : new LinkedHashMap<>(this.authorizationRequest.additionalParameters));
return this.authorizationRequest;
}
}

View File

@ -16,7 +16,7 @@
package org.springframework.security.oauth2.core.endpoint;
/**
* Standard parameters defined in the OAuth Parameters Registry
* Standard and additional (custom) parameters defined in the OAuth Parameters Registry
* and used by the authorization endpoint and token endpoint.
*
* @author Joe Grandja
@ -43,4 +43,6 @@ public interface OAuth2Parameter {
String ERROR_URI = "error_uri";
String REGISTRATION_ID = "registration_id"; // Non-standard additional parameter
}