Store one request by default in WebSessionOAuth2ServerAuthorizationRequestRepository

Related to gh-9649
Closes gh-9857
This commit is contained in:
Steve Riesenberg 2021-06-15 11:03:30 -05:00 committed by Steve Riesenberg
parent e16b88c9d5
commit ee9c8e2fd0
4 changed files with 498 additions and 181 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -34,6 +34,7 @@ import reactor.core.publisher.Mono;
* {@link OAuth2AuthorizationRequest} in the {@code WebSession}. * {@link OAuth2AuthorizationRequest} in the {@code WebSession}.
* *
* @author Rob Winch * @author Rob Winch
* @author Steve Riesenberg
* @since 5.1 * @since 5.1
* @see AuthorizationRequestRepository * @see AuthorizationRequestRepository
* @see OAuth2AuthorizationRequest * @see OAuth2AuthorizationRequest
@ -46,6 +47,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
private boolean allowMultipleAuthorizationRequests;
@Override @Override
public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest( public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
ServerWebExchange exchange) { ServerWebExchange exchange) {
@ -53,17 +56,33 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
if (state == null) { if (state == null) {
return Mono.empty(); return Mono.empty();
} }
return getStateToAuthorizationRequest(exchange) // @formatter:off
.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state)) return this.getSessionAttributes(exchange)
.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state)); .filter((sessionAttrs) -> sessionAttrs.containsKey(this.sessionAttributeName))
.map(this::getAuthorizationRequests)
.filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state))
.map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state));
// @formatter:on
} }
@Override @Override
public Mono<Void> saveAuthorizationRequest( public Mono<Void> saveAuthorizationRequest(
OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) { OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
return saveStateToAuthorizationRequest(exchange) Assert.notNull(exchange, "exchange cannot be null");
.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)) // @formatter:off
return getSessionAttributes(exchange)
.doOnNext((sessionAttrs) -> {
if (this.allowMultipleAuthorizationRequests) {
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(
sessionAttrs);
authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
sessionAttrs.put(this.sessionAttributeName, authorizationRequests);
}
else {
sessionAttrs.put(this.sessionAttributeName, authorizationRequest);
}
})
.then(); .then();
} }
@ -74,27 +93,24 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
if (state == null) { if (state == null) {
return Mono.empty(); return Mono.empty();
} }
return exchange.getSession() // @formatter:off
.map(WebSession::getAttributes) return getSessionAttributes(exchange)
.handle((sessionAttrs, sink) -> { .flatMap((sessionAttrs) -> {
Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs); Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(
if (stateToAuthzRequest == null) { sessionAttrs);
sink.complete(); OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(state);
return; if (authorizationRequests.isEmpty()) {
} sessionAttrs.remove(this.sessionAttributeName);
OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); }
if (stateToAuthzRequest.isEmpty()) { else if (authorizationRequests.size() == 1) {
sessionAttrs.remove(this.sessionAttributeName); sessionAttrs.put(this.sessionAttributeName, authorizationRequests.values().iterator().next());
} else if (removedValue != null) { }
// gh-7327 Overwrite the existing Map to ensure the state is saved for distributed sessions else {
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); sessionAttrs.put(this.sessionAttributeName, authorizationRequests);
} }
if (removedValue == null) { return Mono.justOrEmpty(originalRequest);
sink.complete(); });
} else { // @formatter:on
sink.next(removedValue);
}
});
} }
/** /**
@ -111,31 +127,40 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
return exchange.getSession().map(WebSession::getAttributes); return exchange.getSession().map(WebSession::getAttributes);
} }
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) { private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(Map<String, Object> sessionAttrs) {
Assert.notNull(exchange, "exchange cannot be null"); Object sessionAttributeValue = sessionAttrs.get(this.sessionAttributeName);
if (sessionAttributeValue == null) {
return getSessionAttributes(exchange) return new HashMap<>();
.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); }
else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) {
OAuth2AuthorizationRequest oauth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue;
Map<String, OAuth2AuthorizationRequest> authorizationRequests = new HashMap<>(1);
authorizationRequests.put(oauth2AuthorizationRequest.getState(), oauth2AuthorizationRequest);
return authorizationRequests;
}
else if (sessionAttributeValue instanceof Map) {
@SuppressWarnings("unchecked")
Map<String, OAuth2AuthorizationRequest> authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) sessionAttrs
.get(this.sessionAttributeName);
return authorizationRequests;
}
else {
throw new IllegalStateException(
"authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a "
+ sessionAttributeValue.getClass());
}
} }
private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) { /**
Assert.notNull(exchange, "exchange cannot be null"); * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per
* session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest}
return getSessionAttributes(exchange) * per session).
.doOnNext(sessionAttrs -> { * @param allowMultipleAuthorizationRequests true allows more than one
Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); * {@link OAuth2AuthorizationRequest} to be stored per session.
* @since 5.5
if (stateToAuthzRequest == null) { */
stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>(); @Deprecated
} public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) {
this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests;
// No matter stateToAuthzRequest was in session or not, we should always put it into session again
// in case of redis or hazelcast session. #6215
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
}).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
}
private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {
return (Map<String, OAuth2AuthorizationRequest>) sessionAttrs.get(this.sessionAttributeName);
} }
} }

View File

@ -0,0 +1,252 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import org.springframework.web.server.session.WebSessionManager;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when
* {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
* is enabled.
*
* @author Steve Riesenberg
*/
public class WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests
extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
@Before
public void setup() {
this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
this.repository.setAllowMultipleAuthorizationRequests(true);
}
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
WebSessionManager sessionManager = (e) -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.loadAuthorizationRequest(oldExchange));
StepVerifier.create(saveAndSaveAndLoad)
.expectNext(oldAuthorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.expectNext(this.authorizationRequest)
.verifyComplete();
// @formatter:on
}
// gh-5145
@Test
public void loadAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenReturnOldAuthorizationRequest() {
// save 2 requests with legacy (allowMultipleAuthorizationRequests=true) and load
// with new
WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository();
legacy.setAllowMultipleAuthorizationRequests(true);
// @formatter:off
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state1)
.build();
StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange))
.verifyComplete();
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state2)
.build();
StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange))
.verifyComplete();
ServerHttpRequest newRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, state1)
.build();
ServerWebExchange newExchange = this.exchange.mutate()
.request(newRequest)
.build();
StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange))
.expectNext(authorizationRequest1)
.verifyComplete();
// @formatter:on
}
// gh-5145
@Test
public void saveAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenLoadNewAuthorizationRequest() {
// save 2 requests with legacy (allowMultipleAuthorizationRequests=true), save
// with new, and load with new
WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository();
legacy.setAllowMultipleAuthorizationRequests(true);
// @formatter:off
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state1)
.build();
StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange))
.verifyComplete();
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state2)
.build();
StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange))
.verifyComplete();
String state3 = "state-5566";
OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state3)
.build();
ServerHttpRequest newRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, state3)
.build();
ServerWebExchange newExchange = this.exchange.mutate()
.request(newRequest)
.build();
Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository
.saveAuthorizationRequest(authorizationRequest3, this.exchange)
.then(this.repository.loadAuthorizationRequest(newExchange));
StepVerifier.create(saveAndLoad)
.expectNext(authorizationRequest3)
.verifyComplete();
// @formatter:on
}
@Test
public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
WebSessionManager sessionManager = (e) -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
.expectNext(oldAuthorizationRequest)
.verifyComplete();
// @formatter:on
}
// gh-7327
@Test
public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
Map<String, Object> sessionAttrs = spy(new HashMap<>());
WebSession session = mock(WebSession.class);
given(session.getAttributes()).willReturn(sessionAttrs);
WebSessionManager sessionManager = (e) -> Mono.just(session);
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
// @formatter:on
verify(sessionAttrs, times(3)).put(any(), any());
}
}

View File

@ -0,0 +1,159 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.server;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import org.springframework.web.server.session.WebSessionManager;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when
* {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
* is disabled.
*
* @author Steve Riesenberg
*/
public class WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests
extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
@Before
public void setup() {
this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
this.repository.setAllowMultipleAuthorizationRequests(false);
}
// gh-5145
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() {
// @formatter:off
String state1 = "state-1122";
OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state1)
.build();
StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest1, this.exchange))
.verifyComplete();
String state2 = "state-3344";
OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state2)
.build();
StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest2, this.exchange))
.verifyComplete();
String state3 = "state-5566";
OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(state3)
.build();
StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest3, this.exchange))
.verifyComplete();
ServerHttpRequest newRequest1 = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, state1)
.build();
ServerWebExchange newExchange1 = this.exchange.mutate()
.request(newRequest1)
.build();
StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange1))
.verifyComplete();
ServerHttpRequest newRequest2 = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, state2)
.build();
ServerWebExchange newExchange2 = this.exchange.mutate()
.request(newRequest2)
.build();
StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange2))
.verifyComplete();
ServerHttpRequest newRequest3 = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, state3)
.build();
ServerWebExchange newExchange3 = this.exchange.mutate()
.request(newRequest3)
.build();
StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange3))
.expectNext(authorizationRequest3)
.verifyComplete();
// @formatter:on
}
// gh-5145
@Test
public void removeAuthorizationRequestWhenMultipleThenSessionAttributeRemoved() {
String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
Map<String, Object> sessionAttrs = spy(new HashMap<>());
WebSession session = mock(WebSession.class);
given(session.getAttributes()).willReturn(sessionAttrs);
WebSessionManager sessionManager = (e) -> Mono.just(session);
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
// @formatter:on
verify(sessionAttrs, times(2)).put(anyString(), any());
verify(sessionAttrs).remove(anyString());
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,51 +16,39 @@
package org.springframework.security.oauth2.client.web.server; package org.springframework.security.oauth2.client.web.server;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.junit.Test; import org.junit.Test;
import org.springframework.http.codec.ServerCodecConfigurer; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import org.springframework.web.server.session.WebSessionManager;
import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import reactor.test.StepVerifier;
/** /**
* @author Rob Winch * @author Rob Winch
* @since 5.1 * @since 5.1
*/ */
public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests { public abstract class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
private WebSessionOAuth2ServerAuthorizationRequestRepository repository = protected WebSessionOAuth2ServerAuthorizationRequestRepository repository;
new WebSessionOAuth2ServerAuthorizationRequestRepository();
private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() // @formatter:off
protected OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize") .authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id") .clientId("client-id")
.redirectUri("http://localhost/client-1") .redirectUri("http://localhost/client-1")
.state("state") .state("state")
.build(); .build();
private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/") protected ServerWebExchange exchange = MockServerWebExchange
.queryParam(OAuth2ParameterNames.STATE, "state")); .from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state"));
@Test @Test
public void loadAuthorizationRequestWhenNullExchangeThenIllegalArgumentException() { public void loadAuthorizationRequestWhenNullExchangeThenIllegalArgumentException() {
@ -106,39 +94,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
.verifyComplete(); .verifyComplete();
} }
@Test
public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
String oldState = "state0";
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
WebSessionManager sessionManager = e -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.loadAuthorizationRequest(oldExchange));
StepVerifier.create(saveAndSaveAndLoad)
.expectNext(oldAuthorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.expectNext(this.authorizationRequest)
.verifyComplete();
}
@Test @Test
public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() { public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() {
this.authorizationRequest = null; this.authorizationRequest = null;
@ -203,80 +158,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
.verifyComplete(); .verifyComplete();
} }
@Test
public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
String oldState = "state0";
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
WebSessionManager sessionManager = e -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove)
.expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
.expectNext(oldAuthorizationRequest)
.verifyComplete();
}
// gh-7327
@Test
public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
String oldState = "state0";
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
Map<String, Object> sessionAttrs = spy(new HashMap<>());
WebSession session = mock(WebSession.class);
when(session.getAttributes()).thenReturn(sessionAttrs);
WebSessionManager sessionManager = e -> Mono.just(session);
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove)
.expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
verify(sessionAttrs, times(3)).put(any(), any());
}
private void assertSessionStartedIs(boolean expected) { private void assertSessionStartedIs(boolean expected) {
Mono<Boolean> isStarted = this.exchange.getSession().map(WebSession::isStarted); Mono<Boolean> isStarted = this.exchange.getSession().map(WebSession::isStarted);
StepVerifier.create(isStarted) StepVerifier.create(isStarted)