From bf41d487180784abb38f51241eaa10fa8ee1a191 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 20 Mar 2018 20:45:58 -0500 Subject: [PATCH] HttpSessionOAuth2AuthorizationRequestRepository support distributed HttpSession Previously HttpSessionOAuth2AuthorizationRequestRepository getAuthorizationRequest attempted to update the state of HttpSession as well as getting the Map of OAuth2AuthorizationRequest. This had a few problems - First it was confusing that a get method updated state - It worked when the session was in memory, but would not work when the HttpSesson was persisted to an external store (i.e. Spring Session) since after updating the Map, there was no invocation to update This commit cleans up the logic and ensures that the values are explicitly set in the HttpSession so it works with a session persisted in an external store. Fixes: gh-5146 --- ...nOAuth2AuthorizationRequestRepository.java | 61 +++++++++++-------- ...h2AuthorizationRequestRepositoryTests.java | 59 ++++++++++++++++++ 2 files changed, 95 insertions(+), 25 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index 316eb841ab..4d28716fb9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -30,6 +30,7 @@ import java.util.Map; * {@link OAuth2AuthorizationRequest} in the {@code HttpSession}. * * @author Joe Grandja + * @author Rob Winch * @since 5.0 * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest @@ -37,17 +38,16 @@ import java.util.Map; public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository { private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST"; + private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; @Override public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty"); + String stateParameter = getStateParameter(request); + Assert.hasText(stateParameter, "state parameter cannot be empty"); Map authorizationRequests = this.getAuthorizationRequests(request); - if (authorizationRequests != null) { - return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE)); - } - return null; + return authorizationRequests.get(stateParameter); } @Override @@ -59,35 +59,46 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au this.removeAuthorizationRequest(request); return; } - Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty"); - Map authorizationRequests = this.getAuthorizationRequests(request, true); - authorizationRequests.put(authorizationRequest.getState(), authorizationRequest); + String state = authorizationRequest.getState(); + Assert.hasText(state, "authorizationRequest.state cannot be empty"); + Map authorizationRequests = this.getAuthorizationRequests(request); + authorizationRequests.put(state, authorizationRequest); + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); } @Override public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request); - if (authorizationRequest != null) { - Map authorizationRequests = this.getAuthorizationRequests(request); - authorizationRequests.remove(authorizationRequest.getState()); + String stateParameter = getStateParameter(request); + if (stateParameter == null) { + return null; } - return authorizationRequest; + Map authorizationRequests = this.getAuthorizationRequests(request); + OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter); + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); + return originalRequest; } + /** + * Gets the state parameter from the {@link HttpServletRequest} + * @param request the request to use + * @return the state parameter or null if not found + */ + private String getStateParameter(HttpServletRequest request) { + return request.getParameter(OAuth2ParameterNames.STATE); + } + + /** + * Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest} + * @param request + * @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}. + */ private Map getAuthorizationRequests(HttpServletRequest request) { - return this.getAuthorizationRequests(request, false); - } - - private Map getAuthorizationRequests(HttpServletRequest request, boolean createSession) { - Map authorizationRequests = null; - HttpSession session = request.getSession(createSession); - if (session != null) { - authorizationRequests = (Map) session.getAttribute(this.sessionAttributeName); - if (authorizationRequests == null) { - authorizationRequests = new HashMap<>(); - session.setAttribute(this.sessionAttributeName, authorizationRequests); - } + HttpSession session = request.getSession(false); + Map authorizationRequests = session == null ? null : + (Map) session.getAttribute(this.sessionAttributeName); + if (authorizationRequests == null) { + return new HashMap<>(); } return authorizationRequests; } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index 4fc7bcf472..6f360ddcaa 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -23,9 +23,13 @@ import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import java.util.HashMap; +import java.util.Map; + /** * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}. * @@ -157,6 +161,42 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); } + @Test + public void saveAuthorizationRequestWhenNoExistingSessionAndDistributedSessionThenSaved() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(new MockDistributedHttpSession()); + + OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build(); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest, request, new MockHttpServletResponse()); + + request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState()); + OAuth2AuthorizationRequest loadedAuthorizationRequest = + this.authorizationRequestRepository.loadAuthorizationRequest(request); + + assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); + } + + @Test + public void saveAuthorizationRequestWhenExistingSessionAndDistributedSessionThenSaved() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(new MockDistributedHttpSession()); + + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().build(); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest1, request, new MockHttpServletResponse()); + + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().build(); + this.authorizationRequestRepository.saveAuthorizationRequest( + authorizationRequest2, request, new MockHttpServletResponse()); + + request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest2.getState()); + OAuth2AuthorizationRequest loadedAuthorizationRequest = + this.authorizationRequestRepository.loadAuthorizationRequest(request); + + assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest2); + } + @Test public void saveAuthorizationRequestWhenNullThenRemoved() { MockHttpServletRequest request = new MockHttpServletRequest(); @@ -220,4 +260,23 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { .clientId("client-id-1234") .state("state-1234"); } + + static class MockDistributedHttpSession extends MockHttpSession { + @Override + public Object getAttribute(String name) { + return wrap(super.getAttribute(name)); + } + + @Override + public void setAttribute(String name, Object value) { + super.setAttribute(name, wrap(value)); + } + + private Object wrap(Object object) { + if (object instanceof Map) { + object = new HashMap<>((Map) object); + } + return object; + } + } }