diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java index c40a12234f..2ad3026ca4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -85,6 +85,9 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); if (stateToAuthzRequest.isEmpty()) { sessionAttrs.remove(this.sessionAttributeName); + } else if (removedValue != null) { + // gh-7327 Overwrite the existing Map to ensure the state is saved for distributed sessions + sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); } if (removedValue == null) { sink.complete(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java index 79d17eabd1..b4e11c05be 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,7 +63,7 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .queryParam(OAuth2ParameterNames.STATE, "state")); @Test - public void loadAuthorizatioNRequestWhenNullExchangeThenIllegalArgumentException() { + public void loadAuthorizationRequestWhenNullExchangeThenIllegalArgumentException() { this.exchange = null; assertThatThrownBy(() -> this.repository.loadAuthorizationRequest(this.exchange)) .isInstanceOf(IllegalArgumentException.class); @@ -106,36 +106,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .verifyComplete(); } - @Test - public void multipleSavedAuthorizationRequestAndRedisCookie() { - String oldState = "state0"; - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - - Map sessionAttrs = spy(new HashMap<>()); - WebSession session = mock(WebSession.class); - when(session.getAttributes()).thenReturn(sessionAttrs); - WebSessionManager sessionManager = e -> Mono.just(session); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSave = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)); - - StepVerifier.create(saveAndSave).verifyComplete(); - verify(sessionAttrs, times(2)).put(any(), any()); - } - @Test public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { String oldState = "state0"; @@ -269,6 +239,44 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .verifyComplete(); } + // gh-7327 + @Test + public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { + String oldState = "state0"; + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState).build(); + + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + + Map sessionAttrs = spy(new HashMap<>()); + WebSession session = mock(WebSession.class); + when(session.getAttributes()).thenReturn(sessionAttrs); + WebSessionManager sessionManager = e -> Mono.just(session); + + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + + Mono saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + + StepVerifier.create(saveAndSaveAndRemove) + .expectNext(this.authorizationRequest) + .verifyComplete(); + + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + + verify(sessionAttrs, times(3)).put(any(), any()); + } + private void assertSessionStartedIs(boolean expected) { Mono isStarted = this.exchange.getSession().map(WebSession::isStarted); StepVerifier.create(isStarted)