Request Cache supports matchingRequestParameterName

Closes gh-7157 gh-11453
This commit is contained in:
Rob Winch 2022-06-29 18:19:11 -05:00
parent 459003e1b3
commit 6510274854
8 changed files with 294 additions and 22 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.security.web.jackson2; package org.springframework.security.web.jackson2;
import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
@ -43,4 +44,7 @@ import org.springframework.security.web.savedrequest.DefaultSavedRequest;
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE) @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
abstract class DefaultSavedRequestMixin { abstract class DefaultSavedRequestMixin {
@JsonInclude(JsonInclude.Include.NON_NULL)
String matchingRequestParameterName;
} }

View File

@ -97,8 +97,15 @@ public class DefaultSavedRequest implements SavedRequest {
private final int serverPort; private final int serverPort;
@SuppressWarnings("unchecked") private final String matchingRequestParameterName;
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) { public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) {
this(request, portResolver, null);
}
@SuppressWarnings("unchecked")
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver,
String matchingRequestParameterName) {
Assert.notNull(request, "Request required"); Assert.notNull(request, "Request required");
Assert.notNull(portResolver, "PortResolver required"); Assert.notNull(portResolver, "PortResolver required");
// Cookies // Cookies
@ -131,6 +138,7 @@ public class DefaultSavedRequest implements SavedRequest {
this.serverName = request.getServerName(); this.serverName = request.getServerName();
this.contextPath = request.getContextPath(); this.contextPath = request.getContextPath();
this.servletPath = request.getServletPath(); this.servletPath = request.getServletPath();
this.matchingRequestParameterName = matchingRequestParameterName;
} }
/** /**
@ -147,6 +155,7 @@ public class DefaultSavedRequest implements SavedRequest {
this.serverName = builder.serverName; this.serverName = builder.serverName;
this.servletPath = builder.servletPath; this.servletPath = builder.servletPath;
this.serverPort = builder.serverPort; this.serverPort = builder.serverPort;
this.matchingRequestParameterName = builder.matchingRequestParameterName;
} }
/** /**
@ -264,8 +273,9 @@ public class DefaultSavedRequest implements SavedRequest {
*/ */
@Override @Override
public String getRedirectUrl() { public String getRedirectUrl() {
String queryString = createQueryString(this.queryString, this.matchingRequestParameterName);
return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI, return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI,
this.queryString); queryString);
} }
@Override @Override
@ -353,6 +363,19 @@ public class DefaultSavedRequest implements SavedRequest {
return "DefaultSavedRequest [" + getRedirectUrl() + "]"; return "DefaultSavedRequest [" + getRedirectUrl() + "]";
} }
private static String createQueryString(String queryString, String matchingRequestParameterName) {
if (matchingRequestParameterName == null) {
return queryString;
}
if (queryString == null || queryString.length() == 0) {
return matchingRequestParameterName;
}
if (queryString.endsWith("&")) {
return queryString + matchingRequestParameterName;
}
return queryString + "&" + matchingRequestParameterName;
}
/** /**
* @since 4.2 * @since 4.2
*/ */
@ -388,6 +411,8 @@ public class DefaultSavedRequest implements SavedRequest {
private int serverPort = 80; private int serverPort = 80;
private String matchingRequestParameterName;
public Builder setCookies(List<SavedCookie> cookies) { public Builder setCookies(List<SavedCookie> cookies) {
this.cookies = cookies; this.cookies = cookies;
return this; return this;
@ -458,6 +483,11 @@ public class DefaultSavedRequest implements SavedRequest {
return this; return this;
} }
public Builder setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName;
return this;
}
public DefaultSavedRequest build() { public DefaultSavedRequest build() {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(this); DefaultSavedRequest savedRequest = new DefaultSavedRequest(this);
if (!ObjectUtils.isEmpty(this.cookies)) { if (!ObjectUtils.isEmpty(this.cookies)) {

View File

@ -52,6 +52,8 @@ public class HttpSessionRequestCache implements RequestCache {
private String sessionAttrName = SAVED_REQUEST; private String sessionAttrName = SAVED_REQUEST;
private String matchingRequestParameterName;
/** /**
* Stores the current request, provided the configuration properties allow it. * Stores the current request, provided the configuration properties allow it.
*/ */
@ -64,7 +66,8 @@ public class HttpSessionRequestCache implements RequestCache {
} }
return; return;
} }
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver); DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver,
this.matchingRequestParameterName);
if (this.createSessionAllowed || request.getSession(false) != null) { if (this.createSessionAllowed || request.getSession(false) != null) {
// Store the HTTP request itself. Used by // Store the HTTP request itself. Used by
// AbstractAuthenticationProcessingFilter // AbstractAuthenticationProcessingFilter
@ -96,6 +99,12 @@ public class HttpSessionRequestCache implements RequestCache {
@Override @Override
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) { public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
if (this.matchingRequestParameterName != null
&& request.getParameter(this.matchingRequestParameterName) == null) {
this.logger.trace(
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
return null;
}
SavedRequest saved = getRequest(request, response); SavedRequest saved = getRequest(request, response);
if (saved == null) { if (saved == null) {
this.logger.trace("No saved request"); this.logger.trace("No saved request");
@ -161,4 +170,16 @@ public class HttpSessionRequestCache implements RequestCache {
this.sessionAttrName = sessionAttrName; this.sessionAttrName = sessionAttrName;
} }
/**
* Specify the name of a query parameter that is added to the URL that specifies the
* request cache should be checked in
* {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)}
* @param matchingRequestParameterName the parameter name that must be in the request
* for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check
* the session.
*/
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName;
}
} }

View File

@ -34,8 +34,10 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import org.springframework.web.util.UriComponentsBuilder;
/** /**
* An implementation of {@link ServerRequestCache} that saves the * An implementation of {@link ServerRequestCache} that saves the
@ -57,6 +59,8 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
private ServerWebExchangeMatcher saveRequestMatcher = createDefaultRequestMacher(); private ServerWebExchangeMatcher saveRequestMatcher = createDefaultRequestMacher();
private String matchingRequestParameterName;
/** /**
* Sets the matcher to determine if the request should be saved. The default is to * Sets the matcher to determine if the request should be saved. The default is to
* match on any GET request. * match on any GET request.
@ -81,19 +85,53 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
public Mono<URI> getRedirectUri(ServerWebExchange exchange) { public Mono<URI> getRedirectUri(ServerWebExchange exchange) {
return exchange.getSession() return exchange.getSession()
.flatMap((session) -> Mono.justOrEmpty(session.<String>getAttribute(this.sessionAttrName))) .flatMap((session) -> Mono.justOrEmpty(session.<String>getAttribute(this.sessionAttrName)))
.map(URI::create); .map(this::createRedirectUri);
} }
@Override @Override
public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) { public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
if (this.matchingRequestParameterName != null && !queryParams.containsKey(this.matchingRequestParameterName)) {
this.logger.trace(
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
return Mono.empty();
}
ServerHttpRequest request = stripMatchingRequestParameterName(exchange.getRequest());
return exchange.getSession().map(WebSession::getAttributes).filter((attributes) -> { return exchange.getSession().map(WebSession::getAttributes).filter((attributes) -> {
String requestPath = pathInApplication(exchange.getRequest()); String requestPath = pathInApplication(request);
boolean removed = attributes.remove(this.sessionAttrName, requestPath); boolean removed = attributes.remove(this.sessionAttrName, requestPath);
if (removed) { if (removed) {
logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath)); logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath));
} }
return removed; return removed;
}).map((attributes) -> exchange.getRequest()); }).map((attributes) -> request);
}
/**
* Specify the name of a query parameter that is added to the URL in
* {@link #getRedirectUri(ServerWebExchange)} and is required for
* {@link #removeMatchingRequest(ServerWebExchange)} to look up the
* {@link ServerHttpRequest}.
* @param matchingRequestParameterName the parameter name that must be in the request
* for {@link #removeMatchingRequest(ServerWebExchange)} to check the session.
*/
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName;
}
private ServerHttpRequest stripMatchingRequestParameterName(ServerHttpRequest request) {
if (this.matchingRequestParameterName == null) {
return request;
}
// @formatter:off
URI uri = UriComponentsBuilder.fromUri(request.getURI())
.replaceQueryParam(this.matchingRequestParameterName)
.build()
.toUri();
return request.mutate()
.uri(uri)
.build();
// @formatter:on
} }
private static String pathInApplication(ServerHttpRequest request) { private static String pathInApplication(ServerHttpRequest request) {
@ -102,6 +140,18 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
return path + ((query != null) ? "?" + query : ""); return path + ((query != null) ? "?" + query : "");
} }
private URI createRedirectUri(String uri) {
if (this.matchingRequestParameterName == null) {
return URI.create(uri);
}
// @formatter:off
return UriComponentsBuilder.fromUriString(uri)
.queryParam(this.matchingRequestParameterName)
.build()
.toUri();
// @formatter:on
}
private static ServerWebExchangeMatcher createDefaultRequestMacher() { private static ServerWebExchangeMatcher createDefaultRequestMacher() {
ServerWebExchangeMatcher get = ServerWebExchangeMatchers.pathMatchers(HttpMethod.GET, "/**"); ServerWebExchangeMatcher get = ServerWebExchangeMatchers.pathMatchers(HttpMethod.GET, "/**");
ServerWebExchangeMatcher notFavicon = new NegatedServerWebExchangeMatcher( ServerWebExchangeMatcher notFavicon = new NegatedServerWebExchangeMatcher(
@ -111,4 +161,17 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
return new AndServerWebExchangeMatcher(get, notFavicon, html); return new AndServerWebExchangeMatcher(get, notFavicon, html);
} }
private static String createQueryString(String queryString, String matchingRequestParameterName) {
if (matchingRequestParameterName == null) {
return queryString;
}
if (queryString == null || queryString.length() == 0) {
return matchingRequestParameterName;
}
if (queryString.endsWith("&")) {
return queryString + matchingRequestParameterName;
}
return queryString + "&" + matchingRequestParameterName;
}
} }

View File

@ -55,22 +55,42 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests {
// @formatter:on // @formatter:on
// @formatter:off // @formatter:off
private static final String REQUEST_JSON = "{" + private static final String REQUEST_JSON = "{" +
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", " "\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
+ "\"cookies\": " + COOKIES_JSON + "," + "\"cookies\": " + COOKIES_JSON + ","
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], " + "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, " + "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"}," + "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
+ "\"contextPath\": \"\", " + "\"contextPath\": \"\", "
+ "\"method\": \"\", " + "\"method\": \"\", "
+ "\"pathInfo\": null, " + "\"pathInfo\": null, "
+ "\"queryString\": null, " + "\"queryString\": null, "
+ "\"requestURI\": \"\", " + "\"requestURI\": \"\", "
+ "\"requestURL\": \"http://localhost\", " + "\"requestURL\": \"http://localhost\", "
+ "\"scheme\": \"http\", " + "\"scheme\": \"http\", "
+ "\"serverName\": \"localhost\", " + "\"serverName\": \"localhost\", "
+ "\"servletPath\": \"\", " + "\"servletPath\": \"\", "
+ "\"serverPort\": 80" + "\"serverPort\": 80"
+ "}"; + "}";
// @formatter:on
// @formatter:off
private static final String REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON = "{" +
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
+ "\"cookies\": " + COOKIES_JSON + ","
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
+ "\"contextPath\": \"\", "
+ "\"method\": \"\", "
+ "\"pathInfo\": null, "
+ "\"queryString\": null, "
+ "\"requestURI\": \"\", "
+ "\"requestURL\": \"http://localhost\", "
+ "\"scheme\": \"http\", "
+ "\"serverName\": \"localhost\", "
+ "\"servletPath\": \"\", "
+ "\"serverPort\": 80, "
+ "\"matchingRequestParameterName\": \"success\""
+ "}";
// @formatter:on // @formatter:on
@Test @Test
public void matchRequestBuildWithConstructorAndBuilder() { public void matchRequestBuildWithConstructorAndBuilder() {
@ -125,4 +145,17 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests {
assertThat(request.getHeaderValues("x-auth-token")).hasSize(1).contains("12"); assertThat(request.getHeaderValues("x-auth-token")).hasSize(1).contains("12");
} }
@Test
public void deserializeWhenMatchingRequestParameterNameThenRedirectUrlContainsParam() throws IOException {
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper
.readValue(REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON, Object.class);
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost?success");
}
@Test
public void deserializeWhenNullMatchingRequestParameterNameThenRedirectUrlDoesNotContainParam() throws IOException {
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper.readValue(REQUEST_JSON, Object.class);
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost");
}
} }

View File

@ -16,6 +16,8 @@
package org.springframework.security.web.savedrequest; package org.springframework.security.web.savedrequest;
import java.net.URL;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
@ -57,4 +59,67 @@ public class DefaultSavedRequestTests {
assertThat(saved.getParameterValues("anothertest")).isNull(); assertThat(saved.getParameterValues("anothertest")).isNull();
} }
@Test
public void getRedirectUrlWhenNoQueryAndDefaultMatchingRequestParameterNameThenNoQuery() throws Exception {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(),
new MockPortResolver(8080, 8443));
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasNoQuery();
}
@Test
public void getRedirectUrlWhenQueryAndDefaultMatchingRequestParameterNameNullThenNoQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443), null);
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar");
}
@Test
public void getRedirectUrlWhenNoQueryAndNullMatchingRequestParameterNameThenNoQuery() throws Exception {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(),
new MockPortResolver(8080, 8443), null);
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasNoQuery();
}
@Test
public void getRedirectUrlWhenNoQueryAndMatchingRequestParameterNameThenQuery() throws Exception {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(),
new MockPortResolver(8080, 8443), "success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("success");
}
@Test
public void getRedirectUrlWhenQueryEmptyAndMatchingRequestParameterNameThenQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443),
"success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("success");
}
@Test
public void getRedirectUrlWhenQueryEndsAmpersandAndMatchingRequestParameterNameThenQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar&");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443),
"success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar&success");
}
@Test
public void getRedirectUrlWhenQueryDoesNotEndAmpersandAndMatchingRequestParameterNameThenQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443),
"success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar&success");
}
} }

View File

@ -31,6 +31,10 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.PortResolverImpl; import org.springframework.security.web.PortResolverImpl;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/** /**
* @author Luke Taylor * @author Luke Taylor
@ -92,6 +96,30 @@ public class HttpSessionRequestCacheTests {
assertThat(request.getSession().getAttribute("CUSTOM_SAVED_REQUEST")).isNotNull(); assertThat(request.getSession().getAttribute("CUSTOM_SAVED_REQUEST")).isNotNull();
} }
@Test
public void getMatchingRequestWhenMatchingRequestParameterNameSetThenSessionNotAccessed() {
HttpSessionRequestCache cache = new HttpSessionRequestCache();
cache.setMatchingRequestParameterName("success");
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletRequest matchingRequest = cache.getMatchingRequest(request, new MockHttpServletResponse());
assertThat(matchingRequest).isNull();
verify(request, never()).getSession();
verify(request, never()).getSession(anyBoolean());
}
@Test
public void getMatchingRequestWhenMatchingRequestParameterNameSetAndParameterExistThenLookedUp() {
MockHttpServletRequest request = new MockHttpServletRequest();
HttpSessionRequestCache cache = new HttpSessionRequestCache();
cache.setMatchingRequestParameterName("success");
cache.saveRequest(request, new MockHttpServletResponse());
MockHttpServletRequest requestToMatch = new MockHttpServletRequest();
requestToMatch.setParameter("success", "");
requestToMatch.setSession(request.getSession());
HttpServletRequest matchingRequest = cache.getMatchingRequest(requestToMatch, new MockHttpServletResponse());
assertThat(matchingRequest).isNotNull();
}
private static final class CustomSavedRequest implements SavedRequest { private static final class CustomSavedRequest implements SavedRequest {
private final SavedRequest delegate; private final SavedRequest delegate;

View File

@ -25,8 +25,13 @@ import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/** /**
* @author Rob Winch * @author Rob Winch
@ -96,4 +101,27 @@ public class WebSessionServerRequestCacheTests {
assertThat(this.cache.getRedirectUri(exchange).block()).isNull(); assertThat(this.cache.getRedirectUri(exchange).block()).isNull();
} }
@Test
public void removeMatchingRequestWhenNoParameter() {
this.cache.setMatchingRequestParameterName("success");
MockServerHttpRequest request = MockServerHttpRequest.get("/secured/").build();
ServerWebExchange exchange = mock(ServerWebExchange.class);
given(exchange.getRequest()).willReturn(request);
assertThat(this.cache.removeMatchingRequest(exchange).block()).isNull();
verify(exchange, never()).getSession();
}
@Test
public void removeMatchingRequestWhenParameter() {
this.cache.setMatchingRequestParameterName("success");
MockServerHttpRequest request = MockServerHttpRequest.get("/secured/").accept(MediaType.TEXT_HTML).build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
this.cache.saveRequest(exchange).block();
String redirectUri = "/secured/?success";
assertThat(this.cache.getRedirectUri(exchange).block()).isEqualTo(URI.create(redirectUri));
MockServerHttpRequest redirectRequest = MockServerHttpRequest.get(redirectUri).build();
ServerWebExchange redirectExchange = exchange.mutate().request(redirectRequest).build();
assertThat(this.cache.removeMatchingRequest(redirectExchange).block()).isNotNull();
}
} }