mirror of
https://github.com/spring-projects/spring-security.git
synced 2025-12-26 10:04:08 +00:00
OAuth2AuthorizationEndpointFilter is applied after AuthorizationFilter
Closes gh-18251
This commit is contained in:
parent
244b5a16be
commit
c53e66a217
@ -16,10 +16,12 @@
|
||||
|
||||
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import jakarta.servlet.Filter;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.springframework.http.HttpMethod;
|
||||
@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
|
||||
import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
|
||||
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
|
||||
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
|
||||
import org.springframework.security.web.access.intercept.AuthorizationFilter;
|
||||
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
||||
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
||||
@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM
|
||||
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
|
||||
|
||||
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidator;
|
||||
|
||||
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidatorComposite;
|
||||
|
||||
private SessionAuthenticationStrategy sessionAuthenticationStrategy;
|
||||
|
||||
/**
|
||||
@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
|
||||
authenticationProviders.addAll(0, this.authenticationProviders);
|
||||
}
|
||||
this.authenticationProvidersConsumer.accept(authenticationProviders);
|
||||
authenticationProviders.forEach(
|
||||
(authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
|
||||
authenticationProviders.forEach((authenticationProvider) -> {
|
||||
httpSecurity.authenticationProvider(postProcess(authenticationProvider));
|
||||
if (authenticationProvider instanceof OAuth2AuthorizationCodeRequestAuthenticationProvider) {
|
||||
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
|
||||
"getAuthenticationValidatorComposite");
|
||||
ReflectionUtils.makeAccessible(method);
|
||||
this.authorizationCodeRequestAuthenticationValidatorComposite = (Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext>) ReflectionUtils
|
||||
.invokeMethod(method, authenticationProvider);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
|
||||
if (this.sessionAuthenticationStrategy != null) {
|
||||
authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy);
|
||||
}
|
||||
httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter),
|
||||
httpSecurity.addFilterAfter(postProcess(authorizationEndpointFilter), AuthorizationFilter.class);
|
||||
// Create and add
|
||||
// OAuth2AuthorizationEndpointFilter.OAuth2AuthorizationCodeRequestValidatingFilter
|
||||
Method method = ReflectionUtils.findMethod(OAuth2AuthorizationEndpointFilter.class,
|
||||
"createAuthorizationCodeRequestValidatingFilter", RegisteredClientRepository.class, Consumer.class);
|
||||
ReflectionUtils.makeAccessible(method);
|
||||
RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils
|
||||
.getRegisteredClientRepository(httpSecurity);
|
||||
Filter authorizationCodeRequestValidatingFilter = (Filter) ReflectionUtils.invokeMethod(method,
|
||||
authorizationEndpointFilter, registeredClientRepository,
|
||||
this.authorizationCodeRequestAuthenticationValidatorComposite);
|
||||
httpSecurity.addFilterBefore(postProcess(authorizationCodeRequestValidatingFilter),
|
||||
AbstractPreAuthenticatedProcessingFilter.class);
|
||||
}
|
||||
|
||||
|
||||
@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
|
||||
this.mvc
|
||||
.perform(
|
||||
get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient)))
|
||||
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
|
||||
.queryParams(getAuthorizationRequestParameters(registeredClient)))
|
||||
.andExpect(status().isBadRequest())
|
||||
.andReturn();
|
||||
}
|
||||
@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests {
|
||||
this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
|
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
this.registeredClientRepository.save(registeredClient);
|
||||
|
||||
TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password");
|
||||
Map<String, Object> additionalParameters = new HashMap<>();
|
||||
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
|
||||
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(),
|
||||
additionalParameters);
|
||||
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(),
|
||||
Instant.now().plus(5, ChronoUnit.MINUTES));
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED,
|
||||
registeredClient.getScopes());
|
||||
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
|
||||
given(authorizationRequestAuthenticationProvider
|
||||
.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true);
|
||||
given(authorizationRequestAuthenticationProvider.authenticate(any()))
|
||||
.willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
|
||||
this.mvc
|
||||
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient))
|
||||
.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
|
||||
.queryParams(getAuthorizationRequestParameters(registeredClient))
|
||||
.with(user("user")))
|
||||
.andExpect(status().isOk());
|
||||
|
||||
@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests {
|
||||
|| converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter
|
||||
|| converter instanceof OAuth2AuthorizationConsentAuthenticationConverter);
|
||||
|
||||
verify(authorizationRequestAuthenticationProvider)
|
||||
.authenticate(eq(authorizationCodeRequestAuthenticationResult));
|
||||
verify(authorizationRequestAuthenticationProvider).authenticate(eq(authorizationCodeRequestAuthentication));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor
|
||||
|
||||
@ -190,51 +190,55 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
|
||||
OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder = OAuth2AuthorizationCodeRequestAuthenticationContext
|
||||
.with(authorizationCodeRequestAuthentication)
|
||||
.registeredClient(registeredClient);
|
||||
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder
|
||||
.build();
|
||||
|
||||
// grant_type
|
||||
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
|
||||
.accept(authenticationContext);
|
||||
if (!authorizationCodeRequestAuthentication.isValidated()) {
|
||||
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder
|
||||
.build();
|
||||
|
||||
// redirect_uri and scope
|
||||
this.authenticationValidator.accept(authenticationContext);
|
||||
// grant_type
|
||||
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
|
||||
.accept(authenticationContext);
|
||||
|
||||
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
|
||||
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR
|
||||
.accept(authenticationContext);
|
||||
// redirect_uri and scope
|
||||
this.authenticationValidator.accept(authenticationContext);
|
||||
|
||||
// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
|
||||
Set<String> promptValues = Collections.emptySet();
|
||||
if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
|
||||
String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
|
||||
if (StringUtils.hasText(prompt)) {
|
||||
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR
|
||||
.accept(authenticationContext);
|
||||
promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
|
||||
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
|
||||
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR
|
||||
.accept(authenticationContext);
|
||||
|
||||
// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
|
||||
OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR
|
||||
.accept(authenticationContext);
|
||||
|
||||
authorizationCodeRequestAuthentication.setValidated(true);
|
||||
|
||||
if (this.logger.isTraceEnabled()) {
|
||||
this.logger.trace("Validated authorization code request parameters");
|
||||
}
|
||||
}
|
||||
|
||||
if (this.logger.isTraceEnabled()) {
|
||||
this.logger.trace("Validated authorization code request parameters");
|
||||
}
|
||||
|
||||
// ---------------
|
||||
// The request is valid - ensure the resource owner is authenticated
|
||||
// ---------------
|
||||
|
||||
Authentication principal = (Authentication) authorizationCodeRequestAuthentication.getPrincipal();
|
||||
|
||||
Set<String> promptValues = Collections.emptySet();
|
||||
if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
|
||||
String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
|
||||
if (StringUtils.hasText(prompt)) {
|
||||
promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
|
||||
}
|
||||
}
|
||||
|
||||
if (!isPrincipalAuthenticated(principal)) {
|
||||
if (promptValues.contains(OidcPrompt.NONE)) {
|
||||
// Return an error instead of displaying the login page (via the
|
||||
// configured AuthenticationEntryPoint)
|
||||
throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
|
||||
}
|
||||
if (this.logger.isTraceEnabled()) {
|
||||
this.logger.trace("Did not authenticate authorization code request since principal not authenticated");
|
||||
else {
|
||||
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "principal", authorizationCodeRequestAuthentication,
|
||||
registeredClient);
|
||||
}
|
||||
// Return the authorization request as-is where isAuthenticated() is false
|
||||
return authorizationCodeRequestAuthentication;
|
||||
}
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
|
||||
@ -400,6 +404,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
|
||||
this.authorizationConsentRequired = authorizationConsentRequired;
|
||||
}
|
||||
|
||||
Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> getAuthenticationValidatorComposite() {
|
||||
return OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
|
||||
.andThen(this.authenticationValidator)
|
||||
.andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR)
|
||||
.andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR);
|
||||
}
|
||||
|
||||
private static boolean isAuthorizationConsentRequired(
|
||||
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext) {
|
||||
if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) {
|
||||
|
||||
@ -42,6 +42,8 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
|
||||
|
||||
private final OAuth2AuthorizationCode authorizationCode;
|
||||
|
||||
private boolean validated;
|
||||
|
||||
/**
|
||||
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationToken} using the
|
||||
* provided parameters.
|
||||
@ -89,4 +91,12 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
|
||||
return this.authorizationCode;
|
||||
}
|
||||
|
||||
final boolean isValidated() {
|
||||
return this.validated;
|
||||
}
|
||||
|
||||
final void setValidated(boolean validated) {
|
||||
this.validated = validated;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -17,11 +17,14 @@
|
||||
package org.springframework.security.oauth2.server.authorization.web;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Set;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import jakarta.servlet.Filter;
|
||||
import jakarta.servlet.FilterChain;
|
||||
import jakarta.servlet.ServletException;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
@ -38,14 +41,18 @@ import org.springframework.security.core.AuthenticationException;
|
||||
import org.springframework.security.core.session.SessionRegistry;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationContext;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
|
||||
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
||||
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
|
||||
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
|
||||
import org.springframework.security.web.DefaultRedirectStrategy;
|
||||
@ -64,6 +71,7 @@ import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.filter.OncePerRequestFilter;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
@ -180,21 +188,18 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
||||
}
|
||||
|
||||
try {
|
||||
Authentication authentication = this.authenticationConverter.convert(request);
|
||||
if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
|
||||
authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
||||
// Get the pre-validated authorization code request (if available),
|
||||
// which was set by OAuth2AuthorizationCodeRequestValidatingFilter
|
||||
Authentication authentication = (Authentication) request
|
||||
.getAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
|
||||
if (authentication == null) {
|
||||
authentication = this.authenticationConverter.convert(request);
|
||||
if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
|
||||
authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
||||
}
|
||||
}
|
||||
Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
|
||||
|
||||
if (!authenticationResult.isAuthenticated()) {
|
||||
// If the Principal (Resource Owner) is not authenticated then pass
|
||||
// through the chain
|
||||
// with the expectation that the authentication process will commence via
|
||||
// AuthenticationEntryPoint
|
||||
filterChain.doFilter(request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) {
|
||||
if (this.logger.isTraceEnabled()) {
|
||||
this.logger.trace("Authorization consent is required");
|
||||
@ -401,4 +406,109 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
||||
this.redirectStrategy.sendRedirect(request, response, redirectUri);
|
||||
}
|
||||
|
||||
Filter createAuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
|
||||
Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
|
||||
return new OAuth2AuthorizationCodeRequestValidatingFilter(registeredClientRepository, authenticationValidator);
|
||||
}
|
||||
|
||||
/**
|
||||
* A {@code Filter} that is applied before {@code OAuth2AuthorizationEndpointFilter}
|
||||
* and handles the pre-validation of an OAuth 2.0 Authorization Code Request.
|
||||
*/
|
||||
private final class OAuth2AuthorizationCodeRequestValidatingFilter extends OncePerRequestFilter {
|
||||
|
||||
private final RegisteredClientRepository registeredClientRepository;
|
||||
|
||||
private final Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator;
|
||||
|
||||
private final Field setValidatedField;
|
||||
|
||||
private OAuth2AuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
|
||||
Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
|
||||
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
|
||||
Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
|
||||
this.registeredClientRepository = registeredClientRepository;
|
||||
this.authenticationValidator = authenticationValidator;
|
||||
this.setValidatedField = ReflectionUtils.findField(OAuth2AuthorizationCodeRequestAuthenticationToken.class,
|
||||
"validated");
|
||||
ReflectionUtils.makeAccessible(this.setValidatedField);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
|
||||
FilterChain filterChain) throws ServletException, IOException {
|
||||
|
||||
if (!OAuth2AuthorizationEndpointFilter.this.authorizationEndpointMatcher.matches(request)) {
|
||||
filterChain.doFilter(request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
Authentication authentication = OAuth2AuthorizationEndpointFilter.this.authenticationConverter
|
||||
.convert(request);
|
||||
if (!(authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication)) {
|
||||
filterChain.doFilter(request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
|
||||
.get(OAuth2ParameterNames.REQUEST_URI);
|
||||
if (StringUtils.hasText(requestUri)) {
|
||||
filterChain.doFilter(request, response);
|
||||
return;
|
||||
}
|
||||
|
||||
authorizationCodeRequestAuthentication.setDetails(
|
||||
OAuth2AuthorizationEndpointFilter.this.authenticationDetailsSource.buildDetails(request));
|
||||
|
||||
RegisteredClient registeredClient = this.registeredClientRepository
|
||||
.findByClientId(authorizationCodeRequestAuthentication.getClientId());
|
||||
if (registeredClient == null) {
|
||||
String redirectUri = null; // Prevent redirect
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
authorizationCodeRequestAuthentication.getAuthorizationUri(),
|
||||
authorizationCodeRequestAuthentication.getClientId(),
|
||||
(Authentication) authorizationCodeRequestAuthentication.getPrincipal(), redirectUri,
|
||||
authorizationCodeRequestAuthentication.getState(),
|
||||
authorizationCodeRequestAuthentication.getScopes(),
|
||||
authorizationCodeRequestAuthentication.getAdditionalParameters());
|
||||
|
||||
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,
|
||||
"OAuth 2.0 Parameter: " + OAuth2ParameterNames.CLIENT_ID,
|
||||
"https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1");
|
||||
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
|
||||
authorizationCodeRequestAuthenticationResult);
|
||||
}
|
||||
|
||||
OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = OAuth2AuthorizationCodeRequestAuthenticationContext
|
||||
.with(authorizationCodeRequestAuthentication)
|
||||
.registeredClient(registeredClient)
|
||||
.build();
|
||||
|
||||
this.authenticationValidator.accept(authenticationContext);
|
||||
|
||||
ReflectionUtils.setField(this.setValidatedField, authorizationCodeRequestAuthentication, true);
|
||||
|
||||
// Set the validated authorization code request as a request
|
||||
// attribute
|
||||
// to be used upstream by OAuth2AuthorizationEndpointFilter
|
||||
request.setAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName(),
|
||||
authorizationCodeRequestAuthentication);
|
||||
|
||||
filterChain.doFilter(request, response);
|
||||
}
|
||||
catch (OAuth2AuthenticationException ex) {
|
||||
if (this.logger.isTraceEnabled()) {
|
||||
this.logger.trace(LogMessage.format("Authorization request failed: %s", ex.getError()), ex);
|
||||
}
|
||||
OAuth2AuthorizationEndpointFilter.this.authenticationFailureHandler.onAuthenticationFailure(request,
|
||||
response, ex);
|
||||
}
|
||||
finally {
|
||||
request.removeAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -428,7 +428,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authenticateWhenPrincipalNotAuthenticatedThenReturnAuthorizationCodeRequest() {
|
||||
public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
|
||||
.willReturn(registeredClient);
|
||||
@ -438,12 +438,10 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
|
||||
registeredClient.getScopes(), createPkceParameters());
|
||||
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider
|
||||
.authenticate(authentication);
|
||||
|
||||
assertThat(authenticationResult).isSameAs(authentication);
|
||||
assertThat(authenticationResult.isAuthenticated()).isFalse();
|
||||
assertThatExceptionOfType(OAuth2AuthorizationCodeRequestAuthenticationException.class)
|
||||
.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
|
||||
.satisfies((ex) -> assertAuthenticationException(ex, OAuth2ErrorCodes.INVALID_REQUEST, "principal",
|
||||
authentication.getRedirectUri()));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@ -372,7 +372,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
given(authenticationConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
|
||||
this.filter.setAuthenticationConverter(authenticationConverter);
|
||||
|
||||
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication);
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
|
||||
authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
|
||||
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
@ -382,7 +386,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
|
||||
verify(authenticationConverter).convert(any());
|
||||
verify(this.authenticationManager).authenticate(any());
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
verifyNoInteractions(filterChain);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -461,9 +465,6 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
@Test
|
||||
public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
|
||||
AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource = mock(
|
||||
@ -472,7 +473,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
given(authenticationDetailsSource.buildDetails(request)).willReturn(webAuthenticationDetails);
|
||||
this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
|
||||
|
||||
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication);
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
|
||||
authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
|
||||
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
@ -481,27 +486,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
|
||||
verify(authenticationDetailsSource).buildDetails(any());
|
||||
verify(this.authenticationManager).authenticate(any());
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
|
||||
this.principal.setAuthenticated(false);
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
||||
AUTHORIZATION_URI, registeredClient.getClientId(), this.principal,
|
||||
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
|
||||
authorizationCodeRequestAuthenticationResult.setAuthenticated(false);
|
||||
given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
|
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authenticationManager).authenticate(any());
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
verifyNoInteractions(filterChain);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user