diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index fb39896d10..54187475b8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -96,6 +96,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce */ public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/oauth2/code/*"; private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; + private static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found"; private ClientRegistrationRepository clientRegistrationRepository; private OAuth2AuthorizedClientService authorizedClientService; private AuthorizationRequestRepository authorizationRequestRepository = @@ -146,7 +147,11 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); - + if (clientRegistration == null) { + OAuth2Error oauth2Error = new OAuth2Error(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE, + "Client Registration not found with Id: " + registrationId, null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(request); OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 3c765fe6e4..2b788c8245 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -199,6 +199,45 @@ public class OAuth2LoginAuthenticationFilterTests { assertThat(authenticationException.getError().getErrorCode()).isEqualTo("authorization_request_not_found"); } + // gh-5251 + @Test + public void doFilterWhenAuthorizationResponseClientRegistrationNotFoundThenClientRegistrationNotFoundError() throws Exception { + String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId(); + String state = "state"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.addParameter(OAuth2ParameterNames.CODE, "code"); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + ClientRegistration registrationNotFound = ClientRegistration.withRegistrationId("registration-not-found") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}") + .scope("user") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/oauth2/user") + .userNameAttributeName("id") + .clientName("client-1") + .build(); + this.setUpAuthorizationRequest(request, response, registrationNotFound, state); + + this.filter.doFilter(request, response, filterChain); + + ArgumentCaptor authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class); + verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class), + authenticationExceptionArgCaptor.capture()); + + assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class); + OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue(); + assertThat(authenticationException.getError().getErrorCode()).isEqualTo("client_registration_not_found"); + } + @Test public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception { String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();