From 651027485462e36d6760ee113715daed49671e4f Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 29 Jun 2022 18:19:11 -0500 Subject: [PATCH] Request Cache supports matchingRequestParameterName Closes gh-7157 gh-11453 --- .../jackson2/DefaultSavedRequestMixin.java | 4 ++ .../web/savedrequest/DefaultSavedRequest.java | 34 ++++++++- .../savedrequest/HttpSessionRequestCache.java | 23 ++++++- .../WebSessionServerRequestCache.java | 69 ++++++++++++++++++- .../DefaultSavedRequestMixinTests.java | 65 ++++++++++++----- .../DefaultSavedRequestTests.java | 65 +++++++++++++++++ .../HttpSessionRequestCacheTests.java | 28 ++++++++ .../WebSessionServerRequestCacheTests.java | 28 ++++++++ 8 files changed, 294 insertions(+), 22 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixin.java b/web/src/main/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixin.java index 405ad59cdc..db055a10bc 100644 --- a/web/src/main/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixin.java +++ b/web/src/main/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixin.java @@ -17,6 +17,7 @@ package org.springframework.security.web.jackson2; import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; @@ -43,4 +44,7 @@ import org.springframework.security.web.savedrequest.DefaultSavedRequest; @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE) abstract class DefaultSavedRequestMixin { + @JsonInclude(JsonInclude.Include.NON_NULL) + String matchingRequestParameterName; + } diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java index 83c023a932..adcbe94300 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java @@ -97,8 +97,15 @@ public class DefaultSavedRequest implements SavedRequest { private final int serverPort; - @SuppressWarnings("unchecked") + private final String matchingRequestParameterName; + public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) { + this(request, portResolver, null); + } + + @SuppressWarnings("unchecked") + public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver, + String matchingRequestParameterName) { Assert.notNull(request, "Request required"); Assert.notNull(portResolver, "PortResolver required"); // Cookies @@ -131,6 +138,7 @@ public class DefaultSavedRequest implements SavedRequest { this.serverName = request.getServerName(); this.contextPath = request.getContextPath(); this.servletPath = request.getServletPath(); + this.matchingRequestParameterName = matchingRequestParameterName; } /** @@ -147,6 +155,7 @@ public class DefaultSavedRequest implements SavedRequest { this.serverName = builder.serverName; this.servletPath = builder.servletPath; this.serverPort = builder.serverPort; + this.matchingRequestParameterName = builder.matchingRequestParameterName; } /** @@ -264,8 +273,9 @@ public class DefaultSavedRequest implements SavedRequest { */ @Override public String getRedirectUrl() { + String queryString = createQueryString(this.queryString, this.matchingRequestParameterName); return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI, - this.queryString); + queryString); } @Override @@ -353,6 +363,19 @@ public class DefaultSavedRequest implements SavedRequest { return "DefaultSavedRequest [" + getRedirectUrl() + "]"; } + private static String createQueryString(String queryString, String matchingRequestParameterName) { + if (matchingRequestParameterName == null) { + return queryString; + } + if (queryString == null || queryString.length() == 0) { + return matchingRequestParameterName; + } + if (queryString.endsWith("&")) { + return queryString + matchingRequestParameterName; + } + return queryString + "&" + matchingRequestParameterName; + } + /** * @since 4.2 */ @@ -388,6 +411,8 @@ public class DefaultSavedRequest implements SavedRequest { private int serverPort = 80; + private String matchingRequestParameterName; + public Builder setCookies(List cookies) { this.cookies = cookies; return this; @@ -458,6 +483,11 @@ public class DefaultSavedRequest implements SavedRequest { return this; } + public Builder setMatchingRequestParameterName(String matchingRequestParameterName) { + this.matchingRequestParameterName = matchingRequestParameterName; + return this; + } + public DefaultSavedRequest build() { DefaultSavedRequest savedRequest = new DefaultSavedRequest(this); if (!ObjectUtils.isEmpty(this.cookies)) { diff --git a/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java b/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java index 03c294cbb2..2bec59eb32 100644 --- a/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java @@ -52,6 +52,8 @@ public class HttpSessionRequestCache implements RequestCache { private String sessionAttrName = SAVED_REQUEST; + private String matchingRequestParameterName; + /** * Stores the current request, provided the configuration properties allow it. */ @@ -64,7 +66,8 @@ public class HttpSessionRequestCache implements RequestCache { } return; } - DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver); + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver, + this.matchingRequestParameterName); if (this.createSessionAllowed || request.getSession(false) != null) { // Store the HTTP request itself. Used by // AbstractAuthenticationProcessingFilter @@ -96,6 +99,12 @@ public class HttpSessionRequestCache implements RequestCache { @Override public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) { + if (this.matchingRequestParameterName != null + && request.getParameter(this.matchingRequestParameterName) == null) { + this.logger.trace( + "matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided"); + return null; + } SavedRequest saved = getRequest(request, response); if (saved == null) { this.logger.trace("No saved request"); @@ -161,4 +170,16 @@ public class HttpSessionRequestCache implements RequestCache { this.sessionAttrName = sessionAttrName; } + /** + * Specify the name of a query parameter that is added to the URL that specifies the + * request cache should be checked in + * {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} + * @param matchingRequestParameterName the parameter name that must be in the request + * for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check + * the session. + */ + public void setMatchingRequestParameterName(String matchingRequestParameterName) { + this.matchingRequestParameterName = matchingRequestParameterName; + } + } diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java index df6374cf5b..427e532c6f 100644 --- a/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java @@ -34,8 +34,10 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriComponentsBuilder; /** * An implementation of {@link ServerRequestCache} that saves the @@ -57,6 +59,8 @@ public class WebSessionServerRequestCache implements ServerRequestCache { private ServerWebExchangeMatcher saveRequestMatcher = createDefaultRequestMacher(); + private String matchingRequestParameterName; + /** * Sets the matcher to determine if the request should be saved. The default is to * match on any GET request. @@ -81,19 +85,53 @@ public class WebSessionServerRequestCache implements ServerRequestCache { public Mono getRedirectUri(ServerWebExchange exchange) { return exchange.getSession() .flatMap((session) -> Mono.justOrEmpty(session.getAttribute(this.sessionAttrName))) - .map(URI::create); + .map(this::createRedirectUri); } @Override public Mono removeMatchingRequest(ServerWebExchange exchange) { + MultiValueMap queryParams = exchange.getRequest().getQueryParams(); + if (this.matchingRequestParameterName != null && !queryParams.containsKey(this.matchingRequestParameterName)) { + this.logger.trace( + "matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided"); + return Mono.empty(); + } + ServerHttpRequest request = stripMatchingRequestParameterName(exchange.getRequest()); return exchange.getSession().map(WebSession::getAttributes).filter((attributes) -> { - String requestPath = pathInApplication(exchange.getRequest()); + String requestPath = pathInApplication(request); boolean removed = attributes.remove(this.sessionAttrName, requestPath); if (removed) { logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath)); } return removed; - }).map((attributes) -> exchange.getRequest()); + }).map((attributes) -> request); + } + + /** + * Specify the name of a query parameter that is added to the URL in + * {@link #getRedirectUri(ServerWebExchange)} and is required for + * {@link #removeMatchingRequest(ServerWebExchange)} to look up the + * {@link ServerHttpRequest}. + * @param matchingRequestParameterName the parameter name that must be in the request + * for {@link #removeMatchingRequest(ServerWebExchange)} to check the session. + */ + public void setMatchingRequestParameterName(String matchingRequestParameterName) { + this.matchingRequestParameterName = matchingRequestParameterName; + } + + private ServerHttpRequest stripMatchingRequestParameterName(ServerHttpRequest request) { + if (this.matchingRequestParameterName == null) { + return request; + } + // @formatter:off + URI uri = UriComponentsBuilder.fromUri(request.getURI()) + .replaceQueryParam(this.matchingRequestParameterName) + .build() + .toUri(); + return request.mutate() + .uri(uri) + .build(); + // @formatter:on } private static String pathInApplication(ServerHttpRequest request) { @@ -102,6 +140,18 @@ public class WebSessionServerRequestCache implements ServerRequestCache { return path + ((query != null) ? "?" + query : ""); } + private URI createRedirectUri(String uri) { + if (this.matchingRequestParameterName == null) { + return URI.create(uri); + } + // @formatter:off + return UriComponentsBuilder.fromUriString(uri) + .queryParam(this.matchingRequestParameterName) + .build() + .toUri(); + // @formatter:on + } + private static ServerWebExchangeMatcher createDefaultRequestMacher() { ServerWebExchangeMatcher get = ServerWebExchangeMatchers.pathMatchers(HttpMethod.GET, "/**"); ServerWebExchangeMatcher notFavicon = new NegatedServerWebExchangeMatcher( @@ -111,4 +161,17 @@ public class WebSessionServerRequestCache implements ServerRequestCache { return new AndServerWebExchangeMatcher(get, notFavicon, html); } + private static String createQueryString(String queryString, String matchingRequestParameterName) { + if (matchingRequestParameterName == null) { + return queryString; + } + if (queryString == null || queryString.length() == 0) { + return matchingRequestParameterName; + } + if (queryString.endsWith("&")) { + return queryString + matchingRequestParameterName; + } + return queryString + "&" + matchingRequestParameterName; + } + } diff --git a/web/src/test/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixinTests.java b/web/src/test/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixinTests.java index 739ba98c2e..0a68e96efe 100644 --- a/web/src/test/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixinTests.java +++ b/web/src/test/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixinTests.java @@ -55,22 +55,42 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests { // @formatter:on // @formatter:off private static final String REQUEST_JSON = "{" + - "\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", " - + "\"cookies\": " + COOKIES_JSON + "," - + "\"locales\": [\"java.util.ArrayList\", [\"en\"]], " - + "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, " - + "\"parameters\": {\"@class\": \"java.util.TreeMap\"}," - + "\"contextPath\": \"\", " - + "\"method\": \"\", " - + "\"pathInfo\": null, " - + "\"queryString\": null, " - + "\"requestURI\": \"\", " - + "\"requestURL\": \"http://localhost\", " - + "\"scheme\": \"http\", " - + "\"serverName\": \"localhost\", " - + "\"servletPath\": \"\", " - + "\"serverPort\": 80" - + "}"; + "\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", " + + "\"cookies\": " + COOKIES_JSON + "," + + "\"locales\": [\"java.util.ArrayList\", [\"en\"]], " + + "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, " + + "\"parameters\": {\"@class\": \"java.util.TreeMap\"}," + + "\"contextPath\": \"\", " + + "\"method\": \"\", " + + "\"pathInfo\": null, " + + "\"queryString\": null, " + + "\"requestURI\": \"\", " + + "\"requestURL\": \"http://localhost\", " + + "\"scheme\": \"http\", " + + "\"serverName\": \"localhost\", " + + "\"servletPath\": \"\", " + + "\"serverPort\": 80" + + "}"; + // @formatter:on + // @formatter:off + private static final String REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON = "{" + + "\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", " + + "\"cookies\": " + COOKIES_JSON + "," + + "\"locales\": [\"java.util.ArrayList\", [\"en\"]], " + + "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, " + + "\"parameters\": {\"@class\": \"java.util.TreeMap\"}," + + "\"contextPath\": \"\", " + + "\"method\": \"\", " + + "\"pathInfo\": null, " + + "\"queryString\": null, " + + "\"requestURI\": \"\", " + + "\"requestURL\": \"http://localhost\", " + + "\"scheme\": \"http\", " + + "\"serverName\": \"localhost\", " + + "\"servletPath\": \"\", " + + "\"serverPort\": 80, " + + "\"matchingRequestParameterName\": \"success\"" + + "}"; // @formatter:on @Test public void matchRequestBuildWithConstructorAndBuilder() { @@ -125,4 +145,17 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests { assertThat(request.getHeaderValues("x-auth-token")).hasSize(1).contains("12"); } + @Test + public void deserializeWhenMatchingRequestParameterNameThenRedirectUrlContainsParam() throws IOException { + DefaultSavedRequest request = (DefaultSavedRequest) this.mapper + .readValue(REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON, Object.class); + assertThat(request.getRedirectUrl()).isEqualTo("http://localhost?success"); + } + + @Test + public void deserializeWhenNullMatchingRequestParameterNameThenRedirectUrlDoesNotContainParam() throws IOException { + DefaultSavedRequest request = (DefaultSavedRequest) this.mapper.readValue(REQUEST_JSON, Object.class); + assertThat(request.getRedirectUrl()).isEqualTo("http://localhost"); + } + } diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java index 801e828f75..4bec0f1296 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java @@ -16,6 +16,8 @@ package org.springframework.security.web.savedrequest; +import java.net.URL; + import org.junit.jupiter.api.Test; import org.springframework.mock.web.MockHttpServletRequest; @@ -57,4 +59,67 @@ public class DefaultSavedRequestTests { assertThat(saved.getParameterValues("anothertest")).isNull(); } + @Test + public void getRedirectUrlWhenNoQueryAndDefaultMatchingRequestParameterNameThenNoQuery() throws Exception { + DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(), + new MockPortResolver(8080, 8443)); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasNoQuery(); + } + + @Test + public void getRedirectUrlWhenQueryAndDefaultMatchingRequestParameterNameNullThenNoQuery() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setQueryString("foo=bar"); + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443), null); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar"); + } + + @Test + public void getRedirectUrlWhenNoQueryAndNullMatchingRequestParameterNameThenNoQuery() throws Exception { + DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(), + new MockPortResolver(8080, 8443), null); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasNoQuery(); + } + + @Test + public void getRedirectUrlWhenNoQueryAndMatchingRequestParameterNameThenQuery() throws Exception { + DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(), + new MockPortResolver(8080, 8443), "success"); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("success"); + } + + @Test + public void getRedirectUrlWhenQueryEmptyAndMatchingRequestParameterNameThenQuery() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setQueryString(""); + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443), + "success"); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("success"); + } + + @Test + public void getRedirectUrlWhenQueryEndsAmpersandAndMatchingRequestParameterNameThenQuery() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setQueryString("foo=bar&"); + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443), + "success"); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar&success"); + } + + @Test + public void getRedirectUrlWhenQueryDoesNotEndAmpersandAndMatchingRequestParameterNameThenQuery() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setQueryString("foo=bar"); + DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443), + "success"); + assertThat(savedRequest.getParameterMap()).doesNotContainKey("success"); + assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar&success"); + } + } diff --git a/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java b/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java index fac615d759..1d0c8b48f4 100644 --- a/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java +++ b/web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java @@ -31,6 +31,10 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.PortResolverImpl; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; /** * @author Luke Taylor @@ -92,6 +96,30 @@ public class HttpSessionRequestCacheTests { assertThat(request.getSession().getAttribute("CUSTOM_SAVED_REQUEST")).isNotNull(); } + @Test + public void getMatchingRequestWhenMatchingRequestParameterNameSetThenSessionNotAccessed() { + HttpSessionRequestCache cache = new HttpSessionRequestCache(); + cache.setMatchingRequestParameterName("success"); + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletRequest matchingRequest = cache.getMatchingRequest(request, new MockHttpServletResponse()); + assertThat(matchingRequest).isNull(); + verify(request, never()).getSession(); + verify(request, never()).getSession(anyBoolean()); + } + + @Test + public void getMatchingRequestWhenMatchingRequestParameterNameSetAndParameterExistThenLookedUp() { + MockHttpServletRequest request = new MockHttpServletRequest(); + HttpSessionRequestCache cache = new HttpSessionRequestCache(); + cache.setMatchingRequestParameterName("success"); + cache.saveRequest(request, new MockHttpServletResponse()); + MockHttpServletRequest requestToMatch = new MockHttpServletRequest(); + requestToMatch.setParameter("success", ""); + requestToMatch.setSession(request.getSession()); + HttpServletRequest matchingRequest = cache.getMatchingRequest(requestToMatch, new MockHttpServletResponse()); + assertThat(matchingRequest).isNotNull(); + } + private static final class CustomSavedRequest implements SavedRequest { private final SavedRequest delegate; diff --git a/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java b/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java index 020ce2a0c8..4fb099f79f 100644 --- a/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java +++ b/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java @@ -25,8 +25,13 @@ import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.web.server.ServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; /** * @author Rob Winch @@ -96,4 +101,27 @@ public class WebSessionServerRequestCacheTests { assertThat(this.cache.getRedirectUri(exchange).block()).isNull(); } + @Test + public void removeMatchingRequestWhenNoParameter() { + this.cache.setMatchingRequestParameterName("success"); + MockServerHttpRequest request = MockServerHttpRequest.get("/secured/").build(); + ServerWebExchange exchange = mock(ServerWebExchange.class); + given(exchange.getRequest()).willReturn(request); + assertThat(this.cache.removeMatchingRequest(exchange).block()).isNull(); + verify(exchange, never()).getSession(); + } + + @Test + public void removeMatchingRequestWhenParameter() { + this.cache.setMatchingRequestParameterName("success"); + MockServerHttpRequest request = MockServerHttpRequest.get("/secured/").accept(MediaType.TEXT_HTML).build(); + ServerWebExchange exchange = MockServerWebExchange.from(request); + this.cache.saveRequest(exchange).block(); + String redirectUri = "/secured/?success"; + assertThat(this.cache.getRedirectUri(exchange).block()).isEqualTo(URI.create(redirectUri)); + MockServerHttpRequest redirectRequest = MockServerHttpRequest.get(redirectUri).build(); + ServerWebExchange redirectExchange = exchange.mutate().request(redirectRequest).build(); + assertThat(this.cache.removeMatchingRequest(redirectExchange).block()).isNotNull(); + } + }