From a60fd43534db4dd515f4aea1ba5aef73a3913172 Mon Sep 17 00:00:00 2001 From: Zhanwei Wang Date: Mon, 11 Feb 2019 00:09:32 +0800 Subject: [PATCH] Fix OAuth2 Client with Ditributed Session Fixes: gh-6215 --- ...2ServerAuthorizationRequestRepository.java | 29 ++++++++++----- ...erAuthorizationRequestRepositoryTests.java | 37 +++++++++++++++++++ 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java index ceed2160e7..af98db8603 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java @@ -53,7 +53,7 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository if (state == null) { return Mono.empty(); } - return getStateToAuthorizationRequest(exchange, false) + return getStateToAuthorizationRequest(exchange) .filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state)) .map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state)); } @@ -62,9 +62,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository public Mono saveAuthorizationRequest( OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); - return getStateToAuthorizationRequest(exchange, true) - .doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)) - .then(); + return saveStateToAuthorizationRequest(exchange).doOnNext(stateToAuthorizationRequest -> + stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)).then(); } @Override @@ -108,16 +107,28 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository return exchange.getSession().map(WebSession::getAttributes); } - private Mono> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) { + private Mono> getStateToAuthorizationRequest(ServerWebExchange exchange) { + Assert.notNull(exchange, "exchange cannot be null"); + + return getSessionAttributes(exchange) + .flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + } + + private Mono> saveStateToAuthorizationRequest(ServerWebExchange exchange) { Assert.notNull(exchange, "exchange cannot be null"); return getSessionAttributes(exchange) .doOnNext(sessionAttrs -> { - if (create) { - sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap()); + Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); + + if (stateToAuthzRequest == null) { + stateToAuthzRequest = new HashMap(); } - }) - .flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + + // No matter stateToAuthzRequest was in session or not, we should always put it into session again + // in case of redis or hazelcast session. #6215 + sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); + }).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); } private Map sessionAttrsMapStateToAuthorizationRequest(Map sessionAttrs) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java index 5fe05b69ea..4b36d372b4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java @@ -18,6 +18,13 @@ package org.springframework.security.oauth2.client.web.server; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import java.util.HashMap; import java.util.Map; import org.junit.Test; @@ -99,6 +106,36 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .verifyComplete(); } + @Test + public void multipleSavedAuthorizationRequestAndRedisCookie() { + String oldState = "state0"; + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState).build(); + + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + + Map sessionAttrs = spy(new HashMap<>()); + WebSession session = mock(WebSession.class); + when(session.getAttributes()).thenReturn(sessionAttrs); + WebSessionManager sessionManager = e -> Mono.just(session); + + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + + Mono saveAndSave = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)); + + StepVerifier.create(saveAndSave).verifyComplete(); + verify(sessionAttrs, times(2)).put(any(), any()); + } + @Test public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { String oldState = "state0";