HttpSessionOAuth2AuthorizationRequestRepository support distributed HttpSession

Previously HttpSessionOAuth2AuthorizationRequestRepository
getAuthorizationRequest attempted to update the state of HttpSession as
well as getting the Map of OAuth2AuthorizationRequest. This had a few
problems

- First it was confusing that a get method updated state
- It worked when the session was in memory, but would not work when the
  HttpSesson was persisted to an external store (i.e. Spring Session) since
  after updating the Map, there was no invocation to update

This commit cleans up the logic and ensures that the values are explicitly
set in the HttpSession so it works with a session persisted in an external
store.

Fixes: gh-5146
This commit is contained in:
Rob Winch 2018-03-20 20:45:58 -05:00
parent 04e2e86e6e
commit bf41d48718
2 changed files with 95 additions and 25 deletions

View File

@ -30,6 +30,7 @@ import java.util.Map;
* {@link OAuth2AuthorizationRequest} in the {@code HttpSession}.
*
* @author Joe Grandja
* @author Rob Winch
* @since 5.0
* @see AuthorizationRequestRepository
* @see OAuth2AuthorizationRequest
@ -37,17 +38,16 @@ import java.util.Map;
public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository<OAuth2AuthorizationRequest> {
private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST";
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty");
String stateParameter = getStateParameter(request);
Assert.hasText(stateParameter, "state parameter cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
if (authorizationRequests != null) {
return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE));
}
return null;
return authorizationRequests.get(stateParameter);
}
@Override
@ -59,35 +59,46 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
this.removeAuthorizationRequest(request);
return;
}
Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request, true);
authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
String state = authorizationRequest.getState();
Assert.hasText(state, "authorizationRequest.state cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
authorizationRequests.put(state, authorizationRequest);
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
@Override
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
if (authorizationRequest != null) {
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
authorizationRequests.remove(authorizationRequest.getState());
String stateParameter = getStateParameter(request);
if (stateParameter == null) {
return null;
}
return authorizationRequest;
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter);
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
return originalRequest;
}
/**
* Gets the state parameter from the {@link HttpServletRequest}
* @param request the request to use
* @return the state parameter or null if not found
*/
private String getStateParameter(HttpServletRequest request) {
return request.getParameter(OAuth2ParameterNames.STATE);
}
/**
* Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}
* @param request
* @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}.
*/
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
return this.getAuthorizationRequests(request, false);
}
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request, boolean createSession) {
Map<String, OAuth2AuthorizationRequest> authorizationRequests = null;
HttpSession session = request.getSession(createSession);
if (session != null) {
authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
if (authorizationRequests == null) {
authorizationRequests = new HashMap<>();
session.setAttribute(this.sessionAttributeName, authorizationRequests);
}
HttpSession session = request.getSession(false);
Map<String, OAuth2AuthorizationRequest> authorizationRequests = session == null ? null :
(Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
if (authorizationRequests == null) {
return new HashMap<>();
}
return authorizationRequests;
}

View File

@ -23,9 +23,13 @@ import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import java.util.HashMap;
import java.util.Map;
/**
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
*
@ -157,6 +161,42 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
}
@Test
public void saveAuthorizationRequestWhenNoExistingSessionAndDistributedSessionThenSaved() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setSession(new MockDistributedHttpSession());
OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build();
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest, request, new MockHttpServletResponse());
request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState());
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
}
@Test
public void saveAuthorizationRequestWhenExistingSessionAndDistributedSessionThenSaved() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setSession(new MockDistributedHttpSession());
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().build();
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest1, request, new MockHttpServletResponse());
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().build();
this.authorizationRequestRepository.saveAuthorizationRequest(
authorizationRequest2, request, new MockHttpServletResponse());
request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest2.getState());
OAuth2AuthorizationRequest loadedAuthorizationRequest =
this.authorizationRequestRepository.loadAuthorizationRequest(request);
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest2);
}
@Test
public void saveAuthorizationRequestWhenNullThenRemoved() {
MockHttpServletRequest request = new MockHttpServletRequest();
@ -220,4 +260,23 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
.clientId("client-id-1234")
.state("state-1234");
}
static class MockDistributedHttpSession extends MockHttpSession {
@Override
public Object getAttribute(String name) {
return wrap(super.getAttribute(name));
}
@Override
public void setAttribute(String name, Object value) {
super.setAttribute(name, wrap(value));
}
private Object wrap(Object object) {
if (object instanceof Map) {
object = new HashMap<>((Map<Object, Object>) object);
}
return object;
}
}
}