diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index 316eb841ab..4d28716fb9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -30,6 +30,7 @@ import java.util.Map; * {@link OAuth2AuthorizationRequest} in the {@code HttpSession}. * * @author Joe Grandja + * @author Rob Winch * @since 5.0 * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest @@ -37,17 +38,16 @@ import java.util.Map; public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository { private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST"; + private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; @Override public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty"); + String stateParameter = getStateParameter(request); + Assert.hasText(stateParameter, "state parameter cannot be empty"); Map authorizationRequests = this.getAuthorizationRequests(request); - if (authorizationRequests != null) { - return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE)); - } - return null; + return authorizationRequests.get(stateParameter); } @Override @@ -59,35 +59,46 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au this.removeAuthorizationRequest(request); return; } - Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty"); - Map authorizationRequests = this.getAuthorizationRequests(request, true); - authorizationRequests.put(authorizationRequest.getState(), authorizationRequest); + String state = authorizationRequest.getState(); + Assert.hasText(state, "authorizationRequest.state cannot be empty"); + Map authorizationRequests = this.getAuthorizationRequests(request); + authorizationRequests.put(state, authorizationRequest); + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); } @Override public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request); - if (authorizationRequest != null) { - Map authorizationRequests = this.getAuthorizationRequests(request); - authorizationRequests.remove(authorizationRequest.getState()); + String stateParameter = getStateParameter(request); + if (stateParameter == null) { + return null; } - return authorizationRequest; + Map authorizationRequests = this.getAuthorizationRequests(request); + OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter); + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); + return originalRequest; } + /** + * Gets the state parameter from the {@link HttpServletRequest} + * @param request the request to use + * @return the state parameter or null if not found + */ + private String getStateParameter(HttpServletRequest request) { + return request.getParameter(OAuth2ParameterNames.STATE); + } + + /** + * Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest} + * @param request + * @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}. + */ private Map getAuthorizationRequests(HttpServletRequest request) { - return this.getAuthorizationRequests(request, false); - } - - private Map getAuthorizationRequests(HttpServletRequest request, boolean createSession) { - Map authorizationRequests = null; - HttpSession session = request.getSession(createSession); - if (session != null) { - authorizationRequests = (Map) session.getAttribute(this.sessionAttributeName); - if (authorizationRequests == null) { - authorizationRequests = new HashMap<>(); - session.setAttribute(this.sessionAttributeName, authorizationRequests); - } + HttpSession session = request.getSession(false); + Map authorizationRequests = session == null ? null : + (Map) session.getAttribute(this.sessionAttributeName); + if (authorizationRequests == null) { + return new HashMap<>(); } return authorizationRequests; } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index 4fc7bcf472..6f360ddcaa 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -23,9 +23,13 @@ import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import java.util.HashMap; +import java.util.Map; + /** * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}. * @@ -157,6 +161,42 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); } + @Test + public void saveAuthorizationRequestWhenNoExistingSessionAndDistributedSessionThenSaved() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(new MockDistributedHttpSession()); + + OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest, request, new MockHttpServletResponse()); + + request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + OAuth2AuthorizationRequest loadedAuthorizationRequest = + this.authorizationRequestRepository.loadAuthorizationRequest(request); + + assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); + } + + @Test + public void saveAuthorizationRequestWhenExistingSessionAndDistributedSessionThenSaved() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(new MockDistributedHttpSession()); + + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().build(); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest1, request, new MockHttpServletResponse()); + + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().build(); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest2, request, new MockHttpServletResponse()); + + request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest2.getState()); + OAuth2AuthorizationRequest loadedAuthorizationRequest = + this.authorizationRequestRepository.loadAuthorizationRequest(request); + + assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest2); + } + @Test public void saveAuthorizationRequestWhenNullThenRemoved() { MockHttpServletRequest request = new MockHttpServletRequest(); @@ -220,4 +260,23 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { .clientId("client-id-1234") .state("state-1234"); } + + static class MockDistributedHttpSession extends MockHttpSession { + @Override + public Object getAttribute(String name) { + return wrap(super.getAttribute(name)); + } + + @Override + public void setAttribute(String name, Object value) { + super.setAttribute(name, wrap(value)); + } + + private Object wrap(Object object) { + if (object instanceof Map) { + object = new HashMap<>((Map) object); + } + return object; + } + } }