diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ReactiveAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ReactiveAuthorizationRequestRepository.java new file mode 100644 index 0000000000..05ebd64770 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ReactiveAuthorizationRequestRepository.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.web.server.ServerWebExchange; + +import reactor.core.publisher.Mono; + +/** + * Implementations of this interface are responsible for the persistence + * of {@link OAuth2AuthorizationRequest} between requests. + * + *

+ * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the Authorization Request + * before it initiates the authorization code grant flow. + * As well, used by the {@link OAuth2LoginAuthenticationFilter} for resolving + * the associated Authorization Request when handling the callback of the Authorization Response. + * + * @author Rob Winch + * @since 5.1 + * @see OAuth2AuthorizationRequest + * @see HttpSessionOAuth2AuthorizationRequestRepository + * + * @param The type of OAuth 2.0 Authorization Request + */ +public interface ReactiveAuthorizationRequestRepository { + + /** + * Returns the {@link OAuth2AuthorizationRequest} associated to the provided {@code HttpServletRequest} + * or {@code null} if not available. + * + * @param exchange the {@code ServerWebExchange} + * @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available + */ + Mono loadAuthorizationRequest(ServerWebExchange exchange); + + /** + * Persists the {@link OAuth2AuthorizationRequest} associating it to + * the provided {@code HttpServletRequest} and/or {@code HttpServletResponse}. + * + * @param authorizationRequest the {@link OAuth2AuthorizationRequest} + * @param exchange the {@code ServerWebExchange} + */ + Mono saveAuthorizationRequest(T authorizationRequest, ServerWebExchange exchange); + + /** + * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the + * provided {@code HttpServletRequest} or if not available returns {@code null}. + * + * @param exchange the {@code ServerWebExchange} + * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available + */ + Mono removeAuthorizationRequest(ServerWebExchange exchange); +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepository.java new file mode 100644 index 0000000000..2783c5968f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepository.java @@ -0,0 +1,120 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; + +import reactor.core.publisher.Mono; + +/** + * An implementation of an {@link ReactiveAuthorizationRequestRepository} that stores + * {@link OAuth2AuthorizationRequest} in the {@code WebSession}. + * + * @author Rob Winch + * @since 5.1 + * @see AuthorizationRequestRepository + * @see OAuth2AuthorizationRequest + */ +public final class WebSessionOAuth2ReactiveAuthorizationRequestRepository implements ReactiveAuthorizationRequestRepository { + + private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME = + WebSessionOAuth2ReactiveAuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST"; + + private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; + + @Override + public Mono loadAuthorizationRequest( + ServerWebExchange exchange) { + String state = getStateParameter(exchange); + if (state == null) { + return Mono.empty(); + } + return getStateToAuthorizationRequest(exchange, false) + .filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state)) + .map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state)); + } + + @Override + public Mono saveAuthorizationRequest( + OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { + Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); + return getStateToAuthorizationRequest(exchange, true) + .doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)) + .then(); + } + + @Override + public Mono removeAuthorizationRequest( + ServerWebExchange exchange) { + String state = getStateParameter(exchange); + if (state == null) { + return Mono.empty(); + } + return exchange.getSession() + .map(WebSession::getAttributes) + .handle((sessionAttrs, sink) -> { + Map stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs); + if (stateToAuthzRequest == null) { + sink.complete(); + return; + } + OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); + if (stateToAuthzRequest.isEmpty()) { + sessionAttrs.remove(this.sessionAttributeName); + } + sink.next(removedValue); + }); + } + + /** + * Gets the state parameter from the {@link ServerHttpRequest} + * @param exchange the exchange to use + * @return the state parameter or null if not found + */ + private String getStateParameter(ServerWebExchange exchange) { + Assert.notNull(exchange, "exchange cannot be null"); + return exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.STATE); + } + + private Mono> getSessionAttributes(ServerWebExchange exchange) { + return exchange.getSession().map(WebSession::getAttributes); + } + + private Mono> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) { + Assert.notNull(exchange, "exchange cannot be null"); + + return getSessionAttributes(exchange) + .doOnNext(sessionAttrs -> { + if (create) { + sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap()); + } + }) + .flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); + } + + private Map sessionAttrsMapStateToAuthorizationRequest(Map sessionAttrs) { + return (Map) sessionAttrs.get(this.sessionAttributeName); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests.java new file mode 100644 index 0000000000..c4fc7adaed --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests.java @@ -0,0 +1,224 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Map; + +import org.junit.Test; +import org.springframework.http.codec.ServerCodecConfigurer; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; +import org.springframework.web.server.session.WebSessionManager; + +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * @author Rob Winch + * @since 5.1 + */ +public class WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests { + + private WebSessionOAuth2ReactiveAuthorizationRequestRepository repository = + new WebSessionOAuth2ReactiveAuthorizationRequestRepository(); + + private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state("state") + .build(); + + private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, "state")); + + @Test + public void loadAuthorizatioNRequestWhenNullExchangeThenIllegalArgumentException() { + this.exchange = null; + assertThatThrownBy(() -> this.repository.loadAuthorizationRequest(this.exchange)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizationRequestWhenNoSessionThenEmpty() { + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + + assertSessionStartedIs(false); + } + + @Test + public void loadAuthorizationRequestWhenSessionAndNoRequestThenEmpty() { + Mono setAttrThenLoad = this.exchange.getSession() + .map(WebSession::getAttributes).doOnNext(attrs -> attrs.put("foo", "bar")) + .then(this.repository.loadAuthorizationRequest(this.exchange)); + + StepVerifier.create(setAttrThenLoad) + .verifyComplete(); + } + + @Test + public void loadAuthorizationRequestWhenNoStateParamThenEmpty() { + this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); + Mono saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange) + .then(this.repository.loadAuthorizationRequest(this.exchange)); + + StepVerifier.create(saveAndLoad) + .verifyComplete(); + } + + @Test + public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() { + Mono saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange) + .then(this.repository.loadAuthorizationRequest(this.exchange)); + StepVerifier.create(saveAndLoad) + .expectNext(this.authorizationRequest) + .verifyComplete(); + } + + @Test + public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { + String oldState = "state0"; + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState).build(); + + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + + WebSessionManager sessionManager = e -> this.exchange.getSession(); + + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + + Mono saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.loadAuthorizationRequest(oldExchange)); + + StepVerifier.create(saveAndSaveAndLoad) + .expectNext(oldAuthorizationRequest) + .verifyComplete(); + + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .expectNext(this.authorizationRequest) + .verifyComplete(); + } + + @Test + public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() { + this.authorizationRequest = null; + assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .isInstanceOf(IllegalArgumentException.class); + assertSessionStartedIs(false); + + } + + @Test + public void saveAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() { + this.exchange = null; + assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .isInstanceOf(IllegalArgumentException.class); + + } + + @Test + public void removeAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() { + this.exchange = null; + assertThatThrownBy(() -> this.repository.removeAuthorizationRequest(this.exchange)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void removeAuthorizationRequestWhenNotPresentThenThrowsIllegalArgumentException() { + StepVerifier.create(this.repository.removeAuthorizationRequest(this.exchange)) + .verifyComplete(); + assertSessionStartedIs(false); + } + + @Test + public void removeAuthorizationRequestWhenPresentThenFoundAndRemoved() { + Mono saveAndRemove = this.repository + .saveAuthorizationRequest(this.authorizationRequest, this.exchange) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + + StepVerifier.create(saveAndRemove).expectNext(this.authorizationRequest) + .verifyComplete(); + + StepVerifier.create(this.exchange.getSession() + .map(WebSession::getAttributes) + .map(Map::isEmpty)) + .expectNext(true) + .verifyComplete(); + } + + @Test + public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { + String oldState = "state0"; + MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") + .queryParam(OAuth2ParameterNames.STATE, oldState).build(); + + OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("https://example.com/oauth2/authorize") + .clientId("client-id") + .redirectUri("http://localhost/client-1") + .state(oldState) + .build(); + + WebSessionManager sessionManager = e -> this.exchange.getSession(); + + this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager, + ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + + Mono saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) + .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) + .then(this.repository.removeAuthorizationRequest(this.exchange)); + + StepVerifier.create(saveAndSaveAndRemove) + .expectNext(this.authorizationRequest) + .verifyComplete(); + + StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) + .verifyComplete(); + + StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)) + .expectNext(oldAuthorizationRequest) + .verifyComplete(); + } + + private void assertSessionStartedIs(boolean expected) { + Mono isStarted = this.exchange.getSession().map(WebSession::isStarted); + StepVerifier.create(isStarted) + .expectNext(expected) + .verifyComplete(); + } +}