From ee9c8e2fd0f0345b5f130663138a396b98cd9418 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Tue, 15 Jun 2021 11:03:30 -0500 Subject: [PATCH] Store one request by default in WebSessionOAuth2ServerAuthorizationRequestRepository Related to gh-9649 Closes gh-9857 --- ...2ServerAuthorizationRequestRepository.java | 127 +++++---- ...lowMultipleAuthorizationRequestsTests.java | 252 ++++++++++++++++++ ...lowMultipleAuthorizationRequestsTests.java | 159 +++++++++++ ...erAuthorizationRequestRepositoryTests.java | 141 +--------- 4 files changed, 498 insertions(+), 181 deletions(-) create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java index 2ad3026ca4..3ca3b3c881 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import reactor.core.publisher.Mono; * {@link OAuth2AuthorizationRequest} in the {@code WebSession}. * * @author Rob Winch + * @author Steve Riesenberg * @since 5.1 * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest @@ -46,6 +47,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; + private boolean allowMultipleAuthorizationRequests; + @Override public Mono loadAuthorizationRequest( ServerWebExchange exchange) { @@ -53,17 +56,33 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository if (state == null) { return Mono.empty(); } - return getStateToAuthorizationRequest(exchange) - .filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state)) - .map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state)); + // @formatter:off + return this.getSessionAttributes(exchange) + .filter((sessionAttrs) -> sessionAttrs.containsKey(this.sessionAttributeName)) + .map(this::getAuthorizationRequests) + .filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state)) + .map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state)); + // @formatter:on } @Override public Mono saveAuthorizationRequest( OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); - return saveStateToAuthorizationRequest(exchange) - .doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)) + Assert.notNull(exchange, "exchange cannot be null"); + // @formatter:off + return getSessionAttributes(exchange) + .doOnNext((sessionAttrs) -> { + if (this.allowMultipleAuthorizationRequests) { + Map authorizationRequests = this.getAuthorizationRequests( + sessionAttrs); + authorizationRequests.put(authorizationRequest.getState(), authorizationRequest); + sessionAttrs.put(this.sessionAttributeName, authorizationRequests); + } + else { + sessionAttrs.put(this.sessionAttributeName, authorizationRequest); + } + }) .then(); } @@ -74,27 +93,24 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository if (state == null) { return Mono.empty(); } - return exchange.getSession() - .map(WebSession::getAttributes) - .handle((sessionAttrs, sink) -> { - Map stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs); - if (stateToAuthzRequest == null) { - sink.complete(); - return; - } - OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); - if (stateToAuthzRequest.isEmpty()) { - sessionAttrs.remove(this.sessionAttributeName); - } else if (removedValue != null) { - // gh-7327 Overwrite the existing Map to ensure the state is saved for distributed sessions - sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); - } - if (removedValue == null) { - sink.complete(); - } else { - sink.next(removedValue); - } - }); + // @formatter:off + return getSessionAttributes(exchange) + .flatMap((sessionAttrs) -> { + Map authorizationRequests = this.getAuthorizationRequests( + sessionAttrs); + OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(state); + if (authorizationRequests.isEmpty()) { + sessionAttrs.remove(this.sessionAttributeName); + } + else if (authorizationRequests.size() == 1) { + sessionAttrs.put(this.sessionAttributeName, authorizationRequests.values().iterator().next()); + } + else { + sessionAttrs.put(this.sessionAttributeName, authorizationRequests); + } + return Mono.justOrEmpty(originalRequest); + }); + // @formatter:on } /** @@ -111,31 +127,40 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository return exchange.getSession().map(WebSession::getAttributes); } - private Mono> getStateToAuthorizationRequest(ServerWebExchange exchange) { - Assert.notNull(exchange, "exchange cannot be null"); - - return getSessionAttributes(exchange) - .flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + private Map getAuthorizationRequests(Map sessionAttrs) { + Object sessionAttributeValue = sessionAttrs.get(this.sessionAttributeName); + if (sessionAttributeValue == null) { + return new HashMap<>(); + } + else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) { + OAuth2AuthorizationRequest oauth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue; + Map authorizationRequests = new HashMap<>(1); + authorizationRequests.put(oauth2AuthorizationRequest.getState(), oauth2AuthorizationRequest); + return authorizationRequests; + } + else if (sessionAttributeValue instanceof Map) { + @SuppressWarnings("unchecked") + Map authorizationRequests = (Map) sessionAttrs + .get(this.sessionAttributeName); + return authorizationRequests; + } + else { + throw new IllegalStateException( + "authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a " + + sessionAttributeValue.getClass()); + } } - private Mono> saveStateToAuthorizationRequest(ServerWebExchange exchange) { - Assert.notNull(exchange, "exchange cannot be null"); - - return getSessionAttributes(exchange) - .doOnNext(sessionAttrs -> { - Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); - - if (stateToAuthzRequest == null) { - stateToAuthzRequest = new HashMap(); - } - - // No matter stateToAuthzRequest was in session or not, we should always put it into session again - // in case of redis or hazelcast session. #6215 - sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); - }).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); - } - - private Map sessionAttrsMapStateToAuthorizationRequest(Map sessionAttrs) { - return (Map) sessionAttrs.get(this.sessionAttributeName); + /** + * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per + * session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest} + * per session). + * @param allowMultipleAuthorizationRequests true allows more than one + * {@link OAuth2AuthorizationRequest} to be stored per session. + * @since 5.5 + */ + @Deprecated + public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) { + this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests; } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java new file mode 100644 index 0000000000..5b3a014bf6 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.WebSessionManager; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when + * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)} + * is enabled. + * + * @author Steve Riesenberg + */ + +public class WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests + extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { + + @Before + public void setup() { + this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + this.repository.setAllowMultipleAuthorizationRequests(true); + } + + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + WebSessionManager sessionManager = (e) -> this.exchange.getSession(); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndLoad = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.loadAuthorizationRequest(oldExchange)); + StepVerifier.create(saveAndSaveAndLoad) + .expectNext(oldAuthorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .expectNext(this.authorizationRequest) + .verifyComplete(); + // @formatter:on + } + + // gh-5145 + @Test + public void loadAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenReturnOldAuthorizationRequest() { + // save 2 requests with legacy (allowMultipleAuthorizationRequests=true) and load + // with new + WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + legacy.setAllowMultipleAuthorizationRequests(true); + // @formatter:off + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state1) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange)) + .verifyComplete(); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state2) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange)) + .verifyComplete(); + ServerHttpRequest newRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state1) + .build(); + ServerWebExchange newExchange = this.exchange.mutate() + .request(newRequest) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange)) + .expectNext(authorizationRequest1) + .verifyComplete(); + // @formatter:on + } + + // gh-5145 + @Test + public void saveAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenLoadNewAuthorizationRequest() { + // save 2 requests with legacy (allowMultipleAuthorizationRequests=true), save + // with new, and load with new + WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + legacy.setAllowMultipleAuthorizationRequests(true); + // @formatter:off + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state1) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange)) + .verifyComplete(); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state2) + .build(); + StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange)) + .verifyComplete(); + String state3 = "state-5566"; + OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state3) + .build(); + ServerHttpRequest newRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state3) + .build(); + ServerWebExchange newExchange = this.exchange.mutate() + .request(newRequest) + .build(); + Mono saveAndLoad = this.repository + .saveAuthorizationRequest(authorizationRequest3, this.exchange) + .then(this.repository.loadAuthorizationRequest(newExchange)); + StepVerifier.create(saveAndLoad) + .expectNext(authorizationRequest3) + .verifyComplete(); + // @formatter:on + } + + @Test + public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + WebSessionManager sessionManager = (e) -> this.exchange.getSession(); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)) + .expectNext(oldAuthorizationRequest) + .verifyComplete(); + // @formatter:on + } + + // gh-7327 + @Test + public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + Map sessionAttrs = spy(new HashMap<>()); + WebSession session = mock(WebSession.class); + given(session.getAttributes()).willReturn(sessionAttrs); + WebSessionManager sessionManager = (e) -> Mono.just(session); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + // @formatter:on + verify(sessionAttrs, times(3)).put(any(), any()); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java new file mode 100644 index 0000000000..5e60fd298f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java @@ -0,0 +1,159 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.WebSessionManager; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when + * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)} + * is disabled. + * + * @author Steve Riesenberg + */ +public class WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests + extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { + + @Before + public void setup() { + this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); + this.repository.setAllowMultipleAuthorizationRequests(false); + } + + // gh-5145 + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() { + // @formatter:off + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state1) + .build(); + StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest1, this.exchange)) + .verifyComplete(); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state2) + .build(); + StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest2, this.exchange)) + .verifyComplete(); + String state3 = "state-5566"; + OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(state3) + .build(); + StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest3, this.exchange)) + .verifyComplete(); + ServerHttpRequest newRequest1 = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state1) + .build(); + ServerWebExchange newExchange1 = this.exchange.mutate() + .request(newRequest1) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange1)) + .verifyComplete(); + ServerHttpRequest newRequest2 = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state2) + .build(); + ServerWebExchange newExchange2 = this.exchange.mutate() + .request(newRequest2) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange2)) + .verifyComplete(); + ServerHttpRequest newRequest3 = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, state3) + .build(); + ServerWebExchange newExchange3 = this.exchange.mutate() + .request(newRequest3) + .build(); + StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange3)) + .expectNext(authorizationRequest3) + .verifyComplete(); + // @formatter:on + } + + // gh-5145 + @Test + public void removeAuthorizationRequestWhenMultipleThenSessionAttributeRemoved() { + String oldState = "state0"; + // @formatter:off + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState) + .build(); + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + // @formatter:on + Map sessionAttrs = spy(new HashMap<>()); + WebSession session = mock(WebSession.class); + given(session.getAttributes()).willReturn(sessionAttrs); + WebSessionManager sessionManager = (e) -> Mono.just(session); + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), + sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + // @formatter:off + Mono saveAndSaveAndRemove = this.repository + .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + // @formatter:on + verify(sessionAttrs, times(2)).put(anyString(), any()); + verify(sessionAttrs).remove(anyString()); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java index b4e11c05be..b8b699cfbe 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,51 +16,39 @@ package org.springframework.security.oauth2.client.web.server; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import java.util.HashMap; import java.util.Map; import org.junit.Test; -import org.springframework.http.codec.ServerCodecConfigurer; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.mock.http.server.reactive.MockServerHttpRequest; -import org.springframework.mock.http.server.reactive.MockServerHttpResponse; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; -import org.springframework.web.server.adapter.DefaultServerWebExchange; -import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; -import org.springframework.web.server.session.WebSessionManager; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * @author Rob Winch * @since 5.1 */ -public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { +public abstract class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { - private WebSessionOAuth2ServerAuthorizationRequestRepository repository = - new WebSessionOAuth2ServerAuthorizationRequestRepository(); + protected WebSessionOAuth2ServerAuthorizationRequestRepository repository; - private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + // @formatter:off + protected OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id") .redirectUri("http://localhost/client-1") .state("state") .build(); - private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, "state")); + protected ServerWebExchange exchange = MockServerWebExchange + .from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state")); @Test public void loadAuthorizationRequestWhenNullExchangeThenIllegalArgumentException() { @@ -106,39 +94,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .verifyComplete(); } - @Test - public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { - String oldState = "state0"; - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - - WebSessionManager sessionManager = e -> this.exchange.getSession(); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .then(this.repository.loadAuthorizationRequest(oldExchange)); - - StepVerifier.create(saveAndSaveAndLoad) - .expectNext(oldAuthorizationRequest) - .verifyComplete(); - - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) - .expectNext(this.authorizationRequest) - .verifyComplete(); - } - @Test public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() { this.authorizationRequest = null; @@ -203,80 +158,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { .verifyComplete(); } - @Test - public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { - String oldState = "state0"; - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - - WebSessionManager sessionManager = e -> this.exchange.getSession(); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .then(this.repository.removeAuthorizationRequest(this.exchange)); - - StepVerifier.create(saveAndSaveAndRemove) - .expectNext(this.authorizationRequest) - .verifyComplete(); - - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) - .verifyComplete(); - - StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)) - .expectNext(oldAuthorizationRequest) - .verifyComplete(); - } - - // gh-7327 - @Test - public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { - String oldState = "state0"; - MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") - .queryParam(OAuth2ParameterNames.STATE, oldState).build(); - - OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id") - .redirectUri("http://localhost/client-1") - .state(oldState) - .build(); - - Map sessionAttrs = spy(new HashMap<>()); - WebSession session = mock(WebSession.class); - when(session.getAttributes()).thenReturn(sessionAttrs); - WebSessionManager sessionManager = e -> Mono.just(session); - - this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); - - Mono saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) - .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) - .then(this.repository.removeAuthorizationRequest(this.exchange)); - - StepVerifier.create(saveAndSaveAndRemove) - .expectNext(this.authorizationRequest) - .verifyComplete(); - - StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) - .verifyComplete(); - - verify(sessionAttrs, times(3)).put(any(), any()); - } - private void assertSessionStartedIs(boolean expected) { Mono isStarted = this.exchange.getSession().map(WebSession::isStarted); StepVerifier.create(isStarted)