diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index 4d28716fb9..2cc85b33cd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -44,8 +44,10 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au @Override public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - String stateParameter = getStateParameter(request); - Assert.hasText(stateParameter, "state parameter cannot be empty"); + String stateParameter = this.getStateParameter(request); + if (stateParameter == null) { + return null; + } Map authorizationRequests = this.getAuthorizationRequests(request); return authorizationRequests.get(stateParameter); } @@ -69,7 +71,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au @Override public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - String stateParameter = getStateParameter(request); + String stateParameter = this.getStateParameter(request); if (stateParameter == null) { return null; } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index 6f360ddcaa..081831eba8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -15,9 +15,6 @@ */ package org.springframework.security.oauth2.client.web; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; @@ -30,6 +27,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import java.util.HashMap; import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + /** * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}. * @@ -107,15 +107,14 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { } @Test - public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenThrowIllegalArgumentException() { + public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenReturnNull() { MockHttpServletRequest request = new MockHttpServletRequest(); OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); this.authorizationRequestRepository.saveAuthorizationRequest( authorizationRequest, request, new MockHttpServletResponse()); - assertThatThrownBy(() -> this.authorizationRequestRepository.loadAuthorizationRequest(request)) - .isInstanceOf(IllegalArgumentException.class); + assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull(); } @Test