Add WebSessionOAuth2ReactiveAuthorizationRequestRepository

Issue: gh-4807
This commit is contained in:
Rob Winch 2018-04-25 12:41:06 -05:00
parent 5e9c714ff0
commit b613b2d253
3 changed files with 413 additions and 0 deletions

View File

@ -0,0 +1,69 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/**
* Implementations of this interface are responsible for the persistence
* of {@link OAuth2AuthorizationRequest} between requests.
*
* <p>
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the Authorization Request
* before it initiates the authorization code grant flow.
* As well, used by the {@link OAuth2LoginAuthenticationFilter} for resolving
* the associated Authorization Request when handling the callback of the Authorization Response.
*
* @author Rob Winch
* @since 5.1
* @see OAuth2AuthorizationRequest
* @see HttpSessionOAuth2AuthorizationRequestRepository
*
* @param <T> The type of OAuth 2.0 Authorization Request
*/
public interface ReactiveAuthorizationRequestRepository<T extends OAuth2AuthorizationRequest> {
/**
* Returns the {@link OAuth2AuthorizationRequest} associated to the provided {@code HttpServletRequest}
* or {@code null} if not available.
*
* @param exchange the {@code ServerWebExchange}
* @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
Mono<T> loadAuthorizationRequest(ServerWebExchange exchange);
/**
* Persists the {@link OAuth2AuthorizationRequest} associating it to
* the provided {@code HttpServletRequest} and/or {@code HttpServletResponse}.
*
* @param authorizationRequest the {@link OAuth2AuthorizationRequest}
* @param exchange the {@code ServerWebExchange}
*/
Mono<Void> saveAuthorizationRequest(T authorizationRequest, ServerWebExchange exchange);
/**
* Removes and returns the {@link OAuth2AuthorizationRequest} associated to the
* provided {@code HttpServletRequest} or if not available returns {@code null}.
*
* @param exchange the {@code ServerWebExchange}
* @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
Mono<T> removeAuthorizationRequest(ServerWebExchange exchange);
}

View File

@ -0,0 +1,120 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import java.util.HashMap;
import java.util.Map;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
/**
* An implementation of an {@link ReactiveAuthorizationRequestRepository} that stores
* {@link OAuth2AuthorizationRequest} in the {@code WebSession}.
*
* @author Rob Winch
* @since 5.1
* @see AuthorizationRequestRepository
* @see OAuth2AuthorizationRequest
*/
public final class WebSessionOAuth2ReactiveAuthorizationRequestRepository implements ReactiveAuthorizationRequestRepository<OAuth2AuthorizationRequest> {
private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
WebSessionOAuth2ReactiveAuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST";
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
@Override
public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
ServerWebExchange exchange) {
String state = getStateParameter(exchange);
if (state == null) {
return Mono.empty();
}
return getStateToAuthorizationRequest(exchange, false)
.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state))
.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state));
}
@Override
public Mono<Void> saveAuthorizationRequest(
OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
return getStateToAuthorizationRequest(exchange, true)
.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest))
.then();
}
@Override
public Mono<OAuth2AuthorizationRequest> removeAuthorizationRequest(
ServerWebExchange exchange) {
String state = getStateParameter(exchange);
if (state == null) {
return Mono.empty();
}
return exchange.getSession()
.map(WebSession::getAttributes)
.handle((sessionAttrs, sink) -> {
Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs);
if (stateToAuthzRequest == null) {
sink.complete();
return;
}
OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state);
if (stateToAuthzRequest.isEmpty()) {
sessionAttrs.remove(this.sessionAttributeName);
}
sink.next(removedValue);
});
}
/**
* Gets the state parameter from the {@link ServerHttpRequest}
* @param exchange the exchange to use
* @return the state parameter or null if not found
*/
private String getStateParameter(ServerWebExchange exchange) {
Assert.notNull(exchange, "exchange cannot be null");
return exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.STATE);
}
private Mono<Map<String, Object>> getSessionAttributes(ServerWebExchange exchange) {
return exchange.getSession().map(WebSession::getAttributes);
}
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) {
Assert.notNull(exchange, "exchange cannot be null");
return getSessionAttributes(exchange)
.doOnNext(sessionAttrs -> {
if (create) {
sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap<String, OAuth2AuthorizationRequest>());
}
})
.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
}
private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {
return (Map<String, OAuth2AuthorizationRequest>) sessionAttrs.get(this.sessionAttributeName);
}
}

View File

@ -0,0 +1,224 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import java.util.Map;
import org.junit.Test;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import org.springframework.web.server.session.WebSessionManager;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
/**
* @author Rob Winch
* @since 5.1
*/
public class WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests {
private WebSessionOAuth2ReactiveAuthorizationRequestRepository repository =
new WebSessionOAuth2ReactiveAuthorizationRequestRepository();
private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state("state")
.build();
private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, "state"));
@Test
public void loadAuthorizatioNRequestWhenNullExchangeThenIllegalArgumentException() {
this.exchange = null;
assertThatThrownBy(() -> this.repository.loadAuthorizationRequest(this.exchange))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void loadAuthorizationRequestWhenNoSessionThenEmpty() {
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
assertSessionStartedIs(false);
}
@Test
public void loadAuthorizationRequestWhenSessionAndNoRequestThenEmpty() {
Mono<OAuth2AuthorizationRequest> setAttrThenLoad = this.exchange.getSession()
.map(WebSession::getAttributes).doOnNext(attrs -> attrs.put("foo", "bar"))
.then(this.repository.loadAuthorizationRequest(this.exchange));
StepVerifier.create(setAttrThenLoad)
.verifyComplete();
}
@Test
public void loadAuthorizationRequestWhenNoStateParamThenEmpty() {
this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.loadAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndLoad)
.verifyComplete();
}
@Test
public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() {
Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.loadAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndLoad)
.expectNext(this.authorizationRequest)
.verifyComplete();
}
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
String oldState = "state0";
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
WebSessionManager sessionManager = e -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.loadAuthorizationRequest(oldExchange));
StepVerifier.create(saveAndSaveAndLoad)
.expectNext(oldAuthorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.expectNext(this.authorizationRequest)
.verifyComplete();
}
@Test
public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() {
this.authorizationRequest = null;
assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.isInstanceOf(IllegalArgumentException.class);
assertSessionStartedIs(false);
}
@Test
public void saveAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() {
this.exchange = null;
assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void removeAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() {
this.exchange = null;
assertThatThrownBy(() -> this.repository.removeAuthorizationRequest(this.exchange))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void removeAuthorizationRequestWhenNotPresentThenThrowsIllegalArgumentException() {
StepVerifier.create(this.repository.removeAuthorizationRequest(this.exchange))
.verifyComplete();
assertSessionStartedIs(false);
}
@Test
public void removeAuthorizationRequestWhenPresentThenFoundAndRemoved() {
Mono<OAuth2AuthorizationRequest> saveAndRemove = this.repository
.saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndRemove).expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.exchange.getSession()
.map(WebSession::getAttributes)
.map(Map::isEmpty))
.expectNext(true)
.verifyComplete();
}
@Test
public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
String oldState = "state0";
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
WebSessionManager sessionManager = e -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove)
.expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
.expectNext(oldAuthorizationRequest)
.verifyComplete();
}
private void assertSessionStartedIs(boolean expected) {
Mono<Boolean> isStarted = this.exchange.getSession().map(WebSession::isStarted);
StepVerifier.create(isStarted)
.expectNext(expected)
.verifyComplete();
}
}