Fix OAuth2 Client with Ditributed Session

Fixes: gh-6215
This commit is contained in:
Zhanwei Wang 2019-02-11 00:09:32 +08:00 committed by Rob Winch
parent 80081b0500
commit 4141ddbde2
2 changed files with 57 additions and 9 deletions

View File

@ -53,7 +53,7 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
if (state == null) {
return Mono.empty();
}
return getStateToAuthorizationRequest(exchange, false)
return getStateToAuthorizationRequest(exchange)
.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state))
.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state));
}
@ -62,9 +62,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
public Mono<Void> saveAuthorizationRequest(
OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
return getStateToAuthorizationRequest(exchange, true)
.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest))
.then();
return saveStateToAuthorizationRequest(exchange).doOnNext(stateToAuthorizationRequest ->
stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)).then();
}
@Override
@ -108,16 +107,28 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
return exchange.getSession().map(WebSession::getAttributes);
}
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) {
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) {
Assert.notNull(exchange, "exchange cannot be null");
return getSessionAttributes(exchange)
.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
}
private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) {
Assert.notNull(exchange, "exchange cannot be null");
return getSessionAttributes(exchange)
.doOnNext(sessionAttrs -> {
if (create) {
sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap<String, OAuth2AuthorizationRequest>());
Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
if (stateToAuthzRequest == null) {
stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
}
})
.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
// No matter stateToAuthzRequest was in session or not, we should always put it into session again
// in case of redis or hazelcast session. #6215
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
}).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
}
private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {

View File

@ -18,6 +18,13 @@ package org.springframework.security.oauth2.client.web.server;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
@ -99,6 +106,36 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
.verifyComplete();
}
@Test
public void multipleSavedAuthorizationRequestAndRedisCookie() {
String oldState = "state0";
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
Map<String, Object> sessionAttrs = spy(new HashMap<>());
WebSession session = mock(WebSession.class);
when(session.getAttributes()).thenReturn(sessionAttrs);
WebSessionManager sessionManager = e -> Mono.just(session);
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
Mono<Void> saveAndSave = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange));
StepVerifier.create(saveAndSave).verifyComplete();
verify(sessionAttrs, times(2)).put(any(), any());
}
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
String oldState = "state0";