From 700bda68b7b4507899221fe6774926ce0e8d9f21 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Tue, 15 Jun 2021 11:03:30 -0500 Subject: [PATCH] Store one request by default in WebSessionOAuth2ServerAuthorizationRequestRepository Related to gh-9649 Closes gh-9857 --- ...2ServerAuthorizationRequestRepository.java | 113 ++++---- ...lowMultipleAuthorizationRequestsTests.java | 252 ++++++++++++++++++ ...lowMultipleAuthorizationRequestsTests.java | 159 +++++++++++ ...erAuthorizationRequestRepositoryTests.java | 125 +-------- 4 files changed, 478 insertions(+), 171 deletions(-) create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java index f56953da55..47eee77365 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.web.server.WebSession; * {@link OAuth2AuthorizationRequest} in the {@code WebSession}. * * @author Rob Winch + * @author Steve Riesenberg * @since 5.1 * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest @@ -46,6 +47,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; + private boolean allowMultipleAuthorizationRequests; + @Override public Mono loadAuthorizationRequest(ServerWebExchange exchange) { String state = getStateParameter(exchange); @@ -53,7 +56,9 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository return Mono.empty(); } // @formatter:off - return getStateToAuthorizationRequest(exchange) + return this.getSessionAttributes(exchange) + .filter((sessionAttrs) -> sessionAttrs.containsKey(this.sessionAttributeName)) + .map(this::getAuthorizationRequests) .filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state)) .map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state)); // @formatter:on @@ -63,10 +68,20 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository public Mono saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); + Assert.notNull(exchange, "exchange cannot be null"); // @formatter:off - return saveStateToAuthorizationRequest(exchange) - .doOnNext((stateToAuthorizationRequest) -> stateToAuthorizationRequest - .put(authorizationRequest.getState(), authorizationRequest)) + return getSessionAttributes(exchange) + .doOnNext((sessionAttrs) -> { + if (this.allowMultipleAuthorizationRequests) { + Map authorizationRequests = this.getAuthorizationRequests( + sessionAttrs); + authorizationRequests.put(authorizationRequest.getState(), authorizationRequest); + sessionAttrs.put(this.sessionAttributeName, authorizationRequests); + } + else { + sessionAttrs.put(this.sessionAttributeName, authorizationRequest); + } + }) .then(); // @formatter:on } @@ -78,30 +93,21 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository return Mono.empty(); } // @formatter:off - return exchange.getSession() - .map(WebSession::getAttributes) - .handle((sessionAttrs, sink) -> { - Map stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest( + return getSessionAttributes(exchange) + .flatMap((sessionAttrs) -> { + Map authorizationRequests = this.getAuthorizationRequests( sessionAttrs); - if (stateToAuthzRequest == null) { - sink.complete(); - return; - } - OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); - if (stateToAuthzRequest.isEmpty()) { + OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(state); + if (authorizationRequests.isEmpty()) { sessionAttrs.remove(this.sessionAttributeName); } - else if (removedValue != null) { - // gh-7327 Overwrite the existing Map to ensure the state is saved for - // distributed sessions - sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); - } - if (removedValue == null) { - sink.complete(); + else if (authorizationRequests.size() == 1) { + sessionAttrs.put(this.sessionAttributeName, authorizationRequests.values().iterator().next()); } else { - sink.next(removedValue); + sessionAttrs.put(this.sessionAttributeName, authorizationRequests); } + return Mono.justOrEmpty(originalRequest); }); // @formatter:on } @@ -120,36 +126,41 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository return exchange.getSession().map(WebSession::getAttributes); } - private Mono> getStateToAuthorizationRequest(ServerWebExchange exchange) { - Assert.notNull(exchange, "exchange cannot be null"); - - // @formatter:off - return getSessionAttributes(exchange) - .flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); - // @formatter:on + private Map getAuthorizationRequests(Map sessionAttrs) { + Object sessionAttributeValue = sessionAttrs.get(this.sessionAttributeName); + if (sessionAttributeValue == null) { + return new HashMap<>(); + } + else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) { + OAuth2AuthorizationRequest oauth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue; + Map authorizationRequests = new HashMap<>(1); + authorizationRequests.put(oauth2AuthorizationRequest.getState(), oauth2AuthorizationRequest); + return authorizationRequests; + } + else if (sessionAttributeValue instanceof Map) { + @SuppressWarnings("unchecked") + Map authorizationRequests = (Map) sessionAttrs + .get(this.sessionAttributeName); + return authorizationRequests; + } + else { + throw new IllegalStateException( + "authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a " + + sessionAttributeValue.getClass()); + } } - private Mono> saveStateToAuthorizationRequest(ServerWebExchange exchange) { - Assert.notNull(exchange, "exchange cannot be null"); - // @formatter:off - return getSessionAttributes(exchange) - .doOnNext((sessionAttrs) -> { - Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); - if (stateToAuthzRequest == null) { - stateToAuthzRequest = new HashMap(); - } - // No matter stateToAuthzRequest was in session or not, we should always put - // it into session again - // in case of redis or hazelcast session. #6215 - sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); - }) - .flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); - // @formatter:on - } - - private Map sessionAttrsMapStateToAuthorizationRequest( - Map sessionAttrs) { - return (Map) sessionAttrs.get(this.sessionAttributeName); + /** + * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per + * session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest} + * per session). + * @param allowMultipleAuthorizationRequests true allows more than one + * {@link OAuth2AuthorizationRequest} to be stored per session. + * @since 5.5 + */ + @Deprecated + public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) { + this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests; } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java new file mode 100644 index 0000000000..5b3a014bf6 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.WebSessionManager; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when + * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)} + * is enabled. + * + * @author Steve Riesenberg + */ + +public class WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests + extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { + + @Before + public void setup() { + this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + this.repository.setAllowMultipleAuthorizationRequests(true); + } + + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + WebSessionManager sessionManager = (e) -> this.exchange.getSession(); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndLoad = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.loadAuthorizationRequest(oldExchange)); + StepVerifier.create(saveAndSaveAndLoad) + .expectNext(oldAuthorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .expectNext(this.authorizationRequest) + .verifyComplete(); + // @formatter:on + } + + // gh-5145 + @Test + public void loadAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenReturnOldAuthorizationRequest() { + // save 2 requests with legacy (allowMultipleAuthorizationRequests=true) and load + // with new + WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + legacy.setAllowMultipleAuthorizationRequests(true); + // @formatter:off + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state1) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange)) + .verifyComplete(); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state2) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange)) + .verifyComplete(); + ServerHttpRequest newRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state1) + .build(); + ServerWebExchange newExchange = this.exchange.mutate() + .request(newRequest) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange)) + .expectNext(authorizationRequest1) + .verifyComplete(); + // @formatter:on + } + + // gh-5145 + @Test + public void saveAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenLoadNewAuthorizationRequest() { + // save 2 requests with legacy (allowMultipleAuthorizationRequests=true), save + // with new, and load with new + WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + legacy.setAllowMultipleAuthorizationRequests(true); + // @formatter:off + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state1) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange)) + .verifyComplete(); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state2) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange)) + .verifyComplete(); + String state3 = "state-5566"; + OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state3) + .build(); + ServerHttpRequest newRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state3) + .build(); + ServerWebExchange newExchange = this.exchange.mutate() + .request(newRequest) + .build(); + Mono saveAndLoad = this.repository + .saveAuthorizationRequest(authorizationRequest3, this.exchange) + .then(this.repository.loadAuthorizationRequest(newExchange)); + StepVerifier.create(saveAndLoad) + .expectNext(authorizationRequest3) + .verifyComplete(); + // @formatter:on + } + + @Test + public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + WebSessionManager sessionManager = (e) -> this.exchange.getSession(); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)) + .expectNext(oldAuthorizationRequest) + .verifyComplete(); + // @formatter:on + } + + // gh-7327 + @Test + public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + Map sessionAttrs = spy(new HashMap<>()); + WebSession session = mock(WebSession.class); + given(session.getAttributes()).willReturn(sessionAttrs); + WebSessionManager sessionManager = (e) -> Mono.just(session); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + // @formatter:on + verify(sessionAttrs, times(3)).put(any(), any()); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java new file mode 100644 index 0000000000..5e60fd298f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.WebSessionManager; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when + * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)} + * is disabled. + * + * @author Steve Riesenberg + */ +public class WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests + extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { + + @Before + public void setup() { + this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + this.repository.setAllowMultipleAuthorizationRequests(false); + } + + // gh-5145 + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() { + // @formatter:off + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state1) + .build(); + StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest1, this.exchange)) + .verifyComplete(); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state2) + .build(); + StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest2, this.exchange)) + .verifyComplete(); + String state3 = "state-5566"; + OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state3) + .build(); + StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest3, this.exchange)) + .verifyComplete(); + ServerHttpRequest newRequest1 = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state1) + .build(); + ServerWebExchange newExchange1 = this.exchange.mutate() + .request(newRequest1) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange1)) + .verifyComplete(); + ServerHttpRequest newRequest2 = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state2) + .build(); + ServerWebExchange newExchange2 = this.exchange.mutate() + .request(newRequest2) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange2)) + .verifyComplete(); + ServerHttpRequest newRequest3 = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state3) + .build(); + ServerWebExchange newExchange3 = this.exchange.mutate() + .request(newRequest3) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange3)) + .expectNext(authorizationRequest3) + .verifyComplete(); + // @formatter:on + } + + // gh-5145 + @Test + public void removeAuthorizationRequestWhenMultipleThenSessionAttributeRemoved() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + Map sessionAttrs = spy(new HashMap<>()); + WebSession session = mock(WebSession.class); + given(session.getAttributes()).willReturn(sessionAttrs); + WebSessionManager sessionManager = (e) -> Mono.just(session); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + // @formatter:on + verify(sessionAttrs, times(2)).put(anyString(), any()); + verify(sessionAttrs).remove(anyString()); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java index adae7ac311..c3470b68a2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,43 +16,31 @@ package org.springframework.security.oauth2.client.web.server; -import java.util.HashMap; import java.util.Map; import org.junit.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; -import org.springframework.mock.http.server.reactive.MockServerHttpResponse; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; -import org.springframework.web.server.adapter.DefaultServerWebExchange; -import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; -import org.springframework.web.server.session.WebSessionManager; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; /** * @author Rob Winch * @since 5.1 */ -public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { +public abstract class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { - private WebSessionOAuth2ServerAuthorizationRequestRepository repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + protected WebSessionOAuth2ServerAuthorizationRequestRepository repository; // @formatter:off - private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + protected OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") @@ -60,7 +48,7 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .build(); // @formatter:on - private ServerWebExchange exchange = MockServerWebExchange + protected ServerWebExchange exchange = MockServerWebExchange .from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state")); @Test @@ -114,39 +102,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { // @formatter:on } - @Test - public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { - String oldState = "state0"; - // @formatter:off - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState) - .build(); - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - // @formatter:on - WebSessionManager sessionManager = (e) -> this.exchange.getSession(); - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), - sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), - sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - // @formatter:off - Mono saveAndSaveAndLoad = this.repository - .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .then(this.repository.loadAuthorizationRequest(oldExchange)); - StepVerifier.create(saveAndSaveAndLoad) - .expectNext(oldAuthorizationRequest) - .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) - .expectNext(this.authorizationRequest) - .verifyComplete(); - // @formatter:on - } - @Test public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() { this.authorizationRequest = null; @@ -211,76 +166,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { // @formatter:on } - @Test - public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { - String oldState = "state0"; - // @formatter:off - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState) - .build(); - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - // @formatter:on - WebSessionManager sessionManager = (e) -> this.exchange.getSession(); - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), - sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), - sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - // @formatter:off - Mono saveAndSaveAndRemove = this.repository - .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .then(this.repository.removeAuthorizationRequest(this.exchange)); - StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) - .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) - .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)) - .expectNext(oldAuthorizationRequest) - .verifyComplete(); - // @formatter:on - } - - // gh-7327 - @Test - public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { - String oldState = "state0"; - // @formatter:off - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState) - .build(); - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - // @formatter:on - Map sessionAttrs = spy(new HashMap<>()); - WebSession session = mock(WebSession.class); - given(session.getAttributes()).willReturn(sessionAttrs); - WebSessionManager sessionManager = (e) -> Mono.just(session); - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), - sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), - sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - // @formatter:off - Mono saveAndSaveAndRemove = this.repository - .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .then(this.repository.removeAuthorizationRequest(this.exchange)); - StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) - .verifyComplete(); - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) - .verifyComplete(); - // @formatter:on - verify(sessionAttrs, times(3)).put(any(), any()); - } - private void assertSessionStartedIs(boolean expected) { // @formatter:off Mono isStarted = this.exchange.getSession()