Add HttpServletResponse param to removeAuthorizationRequest

Fixes gh-5313
This commit is contained in:
Joe Grandja 2018-07-24 09:43:26 -04:00
parent 887db71333
commit 2c1c2c78c3
5 changed files with 38 additions and 8 deletions

View File

@ -63,9 +63,22 @@ public interface AuthorizationRequestRepository<T extends OAuth2AuthorizationReq
* Removes and returns the {@link OAuth2AuthorizationRequest} associated to the
* provided {@code HttpServletRequest} or if not available returns {@code null}.
*
* @deprecated Use {@link #removeAuthorizationRequest(HttpServletRequest, HttpServletResponse)} instead
* @param request the {@code HttpServletRequest}
* @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
T removeAuthorizationRequest(HttpServletRequest request);
/**
* Removes and returns the {@link OAuth2AuthorizationRequest} associated to the
* provided {@code HttpServletRequest} or if not available returns {@code null}.
*
* @since 5.1
* @param request the {@code HttpServletRequest}
* @param response the {@code HttpServletResponse}
* @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
default T removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) {
return removeAuthorizationRequest(request);
}
}

View File

@ -58,7 +58,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
Assert.notNull(request, "request cannot be null");
Assert.notNull(response, "response cannot be null");
if (authorizationRequest == null) {
this.removeAuthorizationRequest(request);
this.removeAuthorizationRequest(request, response);
return;
}
String state = authorizationRequest.getState();
@ -85,6 +85,12 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
return originalRequest;
}
@Override
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) {
Assert.notNull(response, "response cannot be null");
return this.removeAuthorizationRequest(request);
}
/**
* Gets the state parameter from the {@link HttpServletRequest}
* @param request the request to use

View File

@ -158,7 +158,8 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.removeAuthorizationRequest(request);
OAuth2AuthorizationRequest authorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);

View File

@ -156,7 +156,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.removeAuthorizationRequest(request);
OAuth2AuthorizationRequest authorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
if (authorizationRequest == null) {
OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());

View File

@ -217,9 +217,16 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
assertThat(loadedAuthorizationRequest).isNull();
}
@Test(expected = IllegalArgumentException.class)
@Test
public void removeAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
this.authorizationRequestRepository.removeAuthorizationRequest(null);
assertThatThrownBy(() -> this.authorizationRequestRepository.removeAuthorizationRequest(
null, new MockHttpServletResponse())).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void removeAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizationRequestRepository.removeAuthorizationRequest(
new MockHttpServletRequest(), null)).isInstanceOf(IllegalArgumentException.class);
}
@Test
@ -234,7 +241,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState());
OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
@ -255,7 +262,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState());
OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
String sessionAttributeName = HttpSessionOAuth2AuthorizationRequestRepository.class.getName() +
".AUTHORIZATION_REQUEST";
@ -269,8 +276,10 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
MockHttpServletResponse response = new MockHttpServletResponse();
OAuth2AuthorizationRequest removedAuthorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request);
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
assertThat(removedAuthorizationRequest).isNull();
}