HttpSessionOAuth2AuthorizationRequestRepository support distributed HttpSession
Previously HttpSessionOAuth2AuthorizationRequestRepository getAuthorizationRequest attempted to update the state of HttpSession as well as getting the Map of OAuth2AuthorizationRequest. This had a few problems - First it was confusing that a get method updated state - It worked when the session was in memory, but would not work when the HttpSesson was persisted to an external store (i.e. Spring Session) since after updating the Map, there was no invocation to update This commit cleans up the logic and ensures that the values are explicitly set in the HttpSession so it works with a session persisted in an external store. Fixes: gh-5146
This commit is contained in:
parent
04e2e86e6e
commit
bf41d48718
|
@ -30,6 +30,7 @@ import java.util.Map;
|
|||
* {@link OAuth2AuthorizationRequest} in the {@code HttpSession}.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @author Rob Winch
|
||||
* @since 5.0
|
||||
* @see AuthorizationRequestRepository
|
||||
* @see OAuth2AuthorizationRequest
|
||||
|
@ -37,17 +38,16 @@ import java.util.Map;
|
|||
public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository<OAuth2AuthorizationRequest> {
|
||||
private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
|
||||
HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST";
|
||||
|
||||
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
|
||||
|
||||
@Override
|
||||
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
|
||||
Assert.notNull(request, "request cannot be null");
|
||||
Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty");
|
||||
String stateParameter = getStateParameter(request);
|
||||
Assert.hasText(stateParameter, "state parameter cannot be empty");
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
||||
if (authorizationRequests != null) {
|
||||
return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE));
|
||||
}
|
||||
return null;
|
||||
return authorizationRequests.get(stateParameter);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -59,35 +59,46 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
|
|||
this.removeAuthorizationRequest(request);
|
||||
return;
|
||||
}
|
||||
Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty");
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request, true);
|
||||
authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
|
||||
String state = authorizationRequest.getState();
|
||||
Assert.hasText(state, "authorizationRequest.state cannot be empty");
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
||||
authorizationRequests.put(state, authorizationRequest);
|
||||
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) {
|
||||
Assert.notNull(request, "request cannot be null");
|
||||
OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
|
||||
if (authorizationRequest != null) {
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
||||
authorizationRequests.remove(authorizationRequest.getState());
|
||||
String stateParameter = getStateParameter(request);
|
||||
if (stateParameter == null) {
|
||||
return null;
|
||||
}
|
||||
return authorizationRequest;
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
||||
OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter);
|
||||
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
|
||||
return originalRequest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the state parameter from the {@link HttpServletRequest}
|
||||
* @param request the request to use
|
||||
* @return the state parameter or null if not found
|
||||
*/
|
||||
private String getStateParameter(HttpServletRequest request) {
|
||||
return request.getParameter(OAuth2ParameterNames.STATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}
|
||||
* @param request
|
||||
* @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}.
|
||||
*/
|
||||
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
|
||||
return this.getAuthorizationRequests(request, false);
|
||||
}
|
||||
|
||||
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request, boolean createSession) {
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = null;
|
||||
HttpSession session = request.getSession(createSession);
|
||||
if (session != null) {
|
||||
authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
|
||||
if (authorizationRequests == null) {
|
||||
authorizationRequests = new HashMap<>();
|
||||
session.setAttribute(this.sessionAttributeName, authorizationRequests);
|
||||
}
|
||||
HttpSession session = request.getSession(false);
|
||||
Map<String, OAuth2AuthorizationRequest> authorizationRequests = session == null ? null :
|
||||
(Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
|
||||
if (authorizationRequests == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return authorizationRequests;
|
||||
}
|
||||
|
|
|
@ -23,9 +23,13 @@ import org.junit.runner.RunWith;
|
|||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.mock.web.MockHttpSession;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
|
||||
*
|
||||
|
@ -157,6 +161,42 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveAuthorizationRequestWhenNoExistingSessionAndDistributedSessionThenSaved() {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setSession(new MockDistributedHttpSession());
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = createAuthorizationRequest().build();
|
||||
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||
authorizationRequest, request, new MockHttpServletResponse());
|
||||
|
||||
request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState());
|
||||
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
||||
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||
|
||||
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveAuthorizationRequestWhenExistingSessionAndDistributedSessionThenSaved() {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setSession(new MockDistributedHttpSession());
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().build();
|
||||
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||
authorizationRequest1, request, new MockHttpServletResponse());
|
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().build();
|
||||
this.authorizationRequestRepository.saveAuthorizationRequest(
|
||||
authorizationRequest2, request, new MockHttpServletResponse());
|
||||
|
||||
request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest2.getState());
|
||||
OAuth2AuthorizationRequest loadedAuthorizationRequest =
|
||||
this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
||||
|
||||
assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveAuthorizationRequestWhenNullThenRemoved() {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
|
@ -220,4 +260,23 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
|
|||
.clientId("client-id-1234")
|
||||
.state("state-1234");
|
||||
}
|
||||
|
||||
static class MockDistributedHttpSession extends MockHttpSession {
|
||||
@Override
|
||||
public Object getAttribute(String name) {
|
||||
return wrap(super.getAttribute(name));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAttribute(String name, Object value) {
|
||||
super.setAttribute(name, wrap(value));
|
||||
}
|
||||
|
||||
private Object wrap(Object object) {
|
||||
if (object instanceof Map) {
|
||||
object = new HashMap<>((Map<Object, Object>) object);
|
||||
}
|
||||
return object;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue