diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index 58ea54f53e..14d04bb9ba 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import java.util.Map; * * @author Joe Grandja * @author Rob Winch + * @author Craig Andrews * @since 5.0 * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest @@ -41,6 +42,8 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; + private boolean allowMultipleAuthorizationRequests; + @Override public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); @@ -63,9 +66,14 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au } String state = authorizationRequest.getState(); Assert.hasText(state, "authorizationRequest.state cannot be empty"); - Map authorizationRequests = this.getAuthorizationRequests(request); - authorizationRequests.put(state, authorizationRequest); - request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); + if (this.allowMultipleAuthorizationRequests) { + Map authorizationRequests = this.getAuthorizationRequests(request); + authorizationRequests.put(state, authorizationRequest); + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); + } + else { + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequest); + } } @Override @@ -77,11 +85,16 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au } Map authorizationRequests = this.getAuthorizationRequests(request); OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter); - if (!authorizationRequests.isEmpty()) { - request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); - } else { + if (authorizationRequests.size() == 0) { request.getSession().removeAttribute(this.sessionAttributeName); } + else if (authorizationRequests.size() == 1) { + request.getSession().setAttribute(this.sessionAttributeName, + authorizationRequests.values().iterator().next()); + } + else { + request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); + } return originalRequest; } @@ -107,11 +120,38 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au */ private Map getAuthorizationRequests(HttpServletRequest request) { HttpSession session = request.getSession(false); - Map authorizationRequests = session == null ? null : - (Map) session.getAttribute(this.sessionAttributeName); - if (authorizationRequests == null) { + Object sessionAttributeValue = (session != null) ? session.getAttribute(this.sessionAttributeName) : null; + if (sessionAttributeValue == null) { return new HashMap<>(); } - return authorizationRequests; + else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) { + OAuth2AuthorizationRequest auth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue; + Map authorizationRequests = new HashMap<>(1); + authorizationRequests.put(auth2AuthorizationRequest.getState(), auth2AuthorizationRequest); + return authorizationRequests; + } + else if (sessionAttributeValue instanceof Map) { + @SuppressWarnings("unchecked") + Map authorizationRequests = (Map) sessionAttributeValue; + return authorizationRequests; + } + else { + throw new IllegalStateException( + "authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a " + + sessionAttributeValue.getClass()); + } + } + + /** + * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per + * session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest} + * per session). + * @param allowMultipleAuthorizationRequests true allows more than one + * {@link OAuth2AuthorizationRequest} to be stored per session. + * @since 5.5 + */ + @Deprecated + public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) { + this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests; } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java new file mode 100644 index 0000000000..0d245fb04c --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository} when + * {@link HttpSessionOAuth2AuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)} + * is enabled. + * + * @author Joe Grandja + * @author Craig Andrews + */ +public class HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests + extends HttpSessionOAuth2AuthorizationRequestRepositoryTests { + + @Before + public void setup() { + this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + this.authorizationRequestRepository.setAllowMultipleAuthorizationRequests(true); + } + + // gh-5110 + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); + String state3 = "state-5566"; + OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); + request.addParameter(OAuth2ParameterNames.STATE, state1); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state2); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state3); + OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java new file mode 100644 index 0000000000..a9b1ebbe4e --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository} when + * {@link HttpSessionOAuth2AuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)} + * is disabled. + * + * @author Joe Grandja + * @author Craig Andrews + */ +public class HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests + extends HttpSessionOAuth2AuthorizationRequestRepositoryTests { + + @Before + public void setup() { + this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + this.authorizationRequestRepository.setAllowMultipleAuthorizationRequests(false); + } + + // gh-5145 + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); + String state3 = "state-5566"; + OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); + request.addParameter(OAuth2ParameterNames.STATE, state1); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest1).isNull(); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state2); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest2).isNull(); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state3); + OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index fe8fc5514a..e7b33bfe0b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,11 +34,12 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}. * * @author Joe Grandja + * @author Craig Andrews */ @RunWith(MockitoJUnitRunner.class) -public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { - private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository = - new HttpSessionOAuth2AuthorizationRequestRepository(); +public abstract class HttpSessionOAuth2AuthorizationRequestRepositoryTests { + + protected HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository; @Test(expected = IllegalArgumentException.class) public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { @@ -70,42 +71,6 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest); } - // gh-5110 - @Test - public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() { - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - - String state1 = "state-1122"; - OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build(); - this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); - - String state2 = "state-3344"; - OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build(); - this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); - - String state3 = "state-5566"; - OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build(); - this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); - - request.addParameter(OAuth2ParameterNames.STATE, state1); - OAuth2AuthorizationRequest loadedAuthorizationRequest1 = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1); - - request.removeParameter(OAuth2ParameterNames.STATE); - request.addParameter(OAuth2ParameterNames.STATE, state2); - OAuth2AuthorizationRequest loadedAuthorizationRequest2 = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); - - request.removeParameter(OAuth2ParameterNames.STATE); - request.addParameter(OAuth2ParameterNames.STATE, state3); - OAuth2AuthorizationRequest loadedAuthorizationRequest3 = - this.authorizationRequestRepository.loadAuthorizationRequest(request); - assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); - } - @Test public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenReturnNull() { MockHttpServletRequest request = new MockHttpServletRequest(); @@ -284,11 +249,9 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { assertThat(removedAuthorizationRequest).isNull(); } - private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() { - return OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri("https://example.com/oauth2/authorize") - .clientId("client-id-1234") - .state("state-1234"); + protected OAuth2AuthorizationRequest.Builder createAuthorizationRequest() { + return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id-1234").state("state-1234"); } static class MockDistributedHttpSession extends MockHttpSession {