Request Cache supports matchingRequestParameterName

This commit is contained in:
Rob Winch 2022-06-29 18:19:11 -05:00
parent 38cb6c3172
commit 28c0d1459c
8 changed files with 294 additions and 22 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.security.web.jackson2;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
@ -43,4 +44,7 @@ import org.springframework.security.web.savedrequest.DefaultSavedRequest;
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
abstract class DefaultSavedRequestMixin {
@JsonInclude(JsonInclude.Include.NON_NULL)
String matchingRequestParameterName;
}

View File

@ -98,8 +98,15 @@ public class DefaultSavedRequest implements SavedRequest {
private final int serverPort;
@SuppressWarnings("unchecked")
private final String matchingRequestParameterName;
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) {
this(request, portResolver, null);
}
@SuppressWarnings("unchecked")
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver,
String matchingRequestParameterName) {
Assert.notNull(request, "Request required");
Assert.notNull(portResolver, "PortResolver required");
// Cookies
@ -132,6 +139,7 @@ public class DefaultSavedRequest implements SavedRequest {
this.serverName = request.getServerName();
this.contextPath = request.getContextPath();
this.servletPath = request.getServletPath();
this.matchingRequestParameterName = matchingRequestParameterName;
}
/**
@ -148,6 +156,7 @@ public class DefaultSavedRequest implements SavedRequest {
this.serverName = builder.serverName;
this.servletPath = builder.servletPath;
this.serverPort = builder.serverPort;
this.matchingRequestParameterName = builder.matchingRequestParameterName;
}
/**
@ -265,8 +274,9 @@ public class DefaultSavedRequest implements SavedRequest {
*/
@Override
public String getRedirectUrl() {
String queryString = createQueryString(this.queryString, this.matchingRequestParameterName);
return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI,
this.queryString);
queryString);
}
@Override
@ -354,6 +364,19 @@ public class DefaultSavedRequest implements SavedRequest {
return "DefaultSavedRequest [" + getRedirectUrl() + "]";
}
private static String createQueryString(String queryString, String matchingRequestParameterName) {
if (matchingRequestParameterName == null) {
return queryString;
}
if (queryString == null || queryString.length() == 0) {
return matchingRequestParameterName;
}
if (queryString.endsWith("&")) {
return queryString + matchingRequestParameterName;
}
return queryString + "&" + matchingRequestParameterName;
}
/**
* @since 4.2
*/
@ -389,6 +412,8 @@ public class DefaultSavedRequest implements SavedRequest {
private int serverPort = 80;
private String matchingRequestParameterName;
public Builder setCookies(List<SavedCookie> cookies) {
this.cookies = cookies;
return this;
@ -459,6 +484,11 @@ public class DefaultSavedRequest implements SavedRequest {
return this;
}
public Builder setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName;
return this;
}
public DefaultSavedRequest build() {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(this);
if (!ObjectUtils.isEmpty(this.cookies)) {

View File

@ -53,6 +53,8 @@ public class HttpSessionRequestCache implements RequestCache {
private String sessionAttrName = SAVED_REQUEST;
private String matchingRequestParameterName;
/**
* Stores the current request, provided the configuration properties allow it.
*/
@ -65,7 +67,8 @@ public class HttpSessionRequestCache implements RequestCache {
}
return;
}
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver);
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver,
this.matchingRequestParameterName);
if (this.createSessionAllowed || request.getSession(false) != null) {
// Store the HTTP request itself. Used by
// AbstractAuthenticationProcessingFilter
@ -97,6 +100,12 @@ public class HttpSessionRequestCache implements RequestCache {
@Override
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
if (this.matchingRequestParameterName != null
&& request.getParameter(this.matchingRequestParameterName) == null) {
this.logger.trace(
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
return null;
}
SavedRequest saved = getRequest(request, response);
if (saved == null) {
this.logger.trace("No saved request");
@ -162,4 +171,16 @@ public class HttpSessionRequestCache implements RequestCache {
this.sessionAttrName = sessionAttrName;
}
/**
* Specify the name of a query parameter that is added to the URL that specifies the
* request cache should be checked in
* {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)}
* @param matchingRequestParameterName the parameter name that must be in the request
* for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check
* the session.
*/
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName;
}
}

View File

@ -34,8 +34,10 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.util.UriComponentsBuilder;
/**
* An implementation of {@link ServerRequestCache} that saves the
@ -57,6 +59,8 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
private ServerWebExchangeMatcher saveRequestMatcher = createDefaultRequestMacher();
private String matchingRequestParameterName;
/**
* Sets the matcher to determine if the request should be saved. The default is to
* match on any GET request.
@ -81,19 +85,53 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
public Mono<URI> getRedirectUri(ServerWebExchange exchange) {
return exchange.getSession()
.flatMap((session) -> Mono.justOrEmpty(session.<String>getAttribute(this.sessionAttrName)))
.map(URI::create);
.map(this::createRedirectUri);
}
@Override
public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) {
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
if (this.matchingRequestParameterName != null && !queryParams.containsKey(this.matchingRequestParameterName)) {
this.logger.trace(
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
return Mono.empty();
}
ServerHttpRequest request = stripMatchingRequestParameterName(exchange.getRequest());
return exchange.getSession().map(WebSession::getAttributes).filter((attributes) -> {
String requestPath = pathInApplication(exchange.getRequest());
String requestPath = pathInApplication(request);
boolean removed = attributes.remove(this.sessionAttrName, requestPath);
if (removed) {
logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath));
}
return removed;
}).map((attributes) -> exchange.getRequest());
}).map((attributes) -> request);
}
/**
* Specify the name of a query parameter that is added to the URL in
* {@link #getRedirectUri(ServerWebExchange)} and is required for
* {@link #removeMatchingRequest(ServerWebExchange)} to look up the
* {@link ServerHttpRequest}.
* @param matchingRequestParameterName the parameter name that must be in the request
* for {@link #removeMatchingRequest(ServerWebExchange)} to check the session.
*/
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
this.matchingRequestParameterName = matchingRequestParameterName;
}
private ServerHttpRequest stripMatchingRequestParameterName(ServerHttpRequest request) {
if (this.matchingRequestParameterName == null) {
return request;
}
// @formatter:off
URI uri = UriComponentsBuilder.fromUri(request.getURI())
.replaceQueryParam(this.matchingRequestParameterName)
.build()
.toUri();
return request.mutate()
.uri(uri)
.build();
// @formatter:on
}
private static String pathInApplication(ServerHttpRequest request) {
@ -102,6 +140,18 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
return path + ((query != null) ? "?" + query : "");
}
private URI createRedirectUri(String uri) {
if (this.matchingRequestParameterName == null) {
return URI.create(uri);
}
// @formatter:off
return UriComponentsBuilder.fromUriString(uri)
.queryParam(this.matchingRequestParameterName)
.build()
.toUri();
// @formatter:on
}
private static ServerWebExchangeMatcher createDefaultRequestMacher() {
ServerWebExchangeMatcher get = ServerWebExchangeMatchers.pathMatchers(HttpMethod.GET, "/**");
ServerWebExchangeMatcher notFavicon = new NegatedServerWebExchangeMatcher(
@ -111,4 +161,17 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
return new AndServerWebExchangeMatcher(get, notFavicon, html);
}
private static String createQueryString(String queryString, String matchingRequestParameterName) {
if (matchingRequestParameterName == null) {
return queryString;
}
if (queryString == null || queryString.length() == 0) {
return matchingRequestParameterName;
}
if (queryString.endsWith("&")) {
return queryString + matchingRequestParameterName;
}
return queryString + "&" + matchingRequestParameterName;
}
}

View File

@ -56,22 +56,42 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests {
// @formatter:on
// @formatter:off
private static final String REQUEST_JSON = "{" +
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
+ "\"cookies\": " + COOKIES_JSON + ","
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
+ "\"contextPath\": \"\", "
+ "\"method\": \"\", "
+ "\"pathInfo\": null, "
+ "\"queryString\": null, "
+ "\"requestURI\": \"\", "
+ "\"requestURL\": \"http://localhost\", "
+ "\"scheme\": \"http\", "
+ "\"serverName\": \"localhost\", "
+ "\"servletPath\": \"\", "
+ "\"serverPort\": 80"
+ "}";
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
+ "\"cookies\": " + COOKIES_JSON + ","
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
+ "\"contextPath\": \"\", "
+ "\"method\": \"\", "
+ "\"pathInfo\": null, "
+ "\"queryString\": null, "
+ "\"requestURI\": \"\", "
+ "\"requestURL\": \"http://localhost\", "
+ "\"scheme\": \"http\", "
+ "\"serverName\": \"localhost\", "
+ "\"servletPath\": \"\", "
+ "\"serverPort\": 80"
+ "}";
// @formatter:on
// @formatter:off
private static final String REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON = "{" +
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
+ "\"cookies\": " + COOKIES_JSON + ","
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
+ "\"contextPath\": \"\", "
+ "\"method\": \"\", "
+ "\"pathInfo\": null, "
+ "\"queryString\": null, "
+ "\"requestURI\": \"\", "
+ "\"requestURL\": \"http://localhost\", "
+ "\"scheme\": \"http\", "
+ "\"serverName\": \"localhost\", "
+ "\"servletPath\": \"\", "
+ "\"serverPort\": 80, "
+ "\"matchingRequestParameterName\": \"success\""
+ "}";
// @formatter:on
@Test
public void matchRequestBuildWithConstructorAndBuilder() {
@ -126,4 +146,17 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests {
assertThat(request.getHeaderValues("x-auth-token")).hasSize(1).contains("12");
}
@Test
public void deserializeWhenMatchingRequestParameterNameThenRedirectUrlContainsParam() throws IOException {
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper
.readValue(REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON, Object.class);
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost?success");
}
@Test
public void deserializeWhenNullMatchingRequestParameterNameThenRedirectUrlDoesNotContainParam() throws IOException {
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper.readValue(REQUEST_JSON, Object.class);
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost");
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.security.web.savedrequest;
import java.net.URL;
import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockHttpServletRequest;
@ -57,4 +59,67 @@ public class DefaultSavedRequestTests {
assertThat(saved.getParameterValues("anothertest")).isNull();
}
@Test
public void getRedirectUrlWhenNoQueryAndDefaultMatchingRequestParameterNameThenNoQuery() throws Exception {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(),
new MockPortResolver(8080, 8443));
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasNoQuery();
}
@Test
public void getRedirectUrlWhenQueryAndDefaultMatchingRequestParameterNameNullThenNoQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443), null);
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar");
}
@Test
public void getRedirectUrlWhenNoQueryAndNullMatchingRequestParameterNameThenNoQuery() throws Exception {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(),
new MockPortResolver(8080, 8443), null);
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasNoQuery();
}
@Test
public void getRedirectUrlWhenNoQueryAndMatchingRequestParameterNameThenQuery() throws Exception {
DefaultSavedRequest savedRequest = new DefaultSavedRequest(new MockHttpServletRequest(),
new MockPortResolver(8080, 8443), "success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("success");
}
@Test
public void getRedirectUrlWhenQueryEmptyAndMatchingRequestParameterNameThenQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443),
"success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("success");
}
@Test
public void getRedirectUrlWhenQueryEndsAmpersandAndMatchingRequestParameterNameThenQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar&");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443),
"success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar&success");
}
@Test
public void getRedirectUrlWhenQueryDoesNotEndAmpersandAndMatchingRequestParameterNameThenQuery() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setQueryString("foo=bar");
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443),
"success");
assertThat(savedRequest.getParameterMap()).doesNotContainKey("success");
assertThat(new URL(savedRequest.getRedirectUrl())).hasQuery("foo=bar&success");
}
}

View File

@ -32,6 +32,10 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.PortResolverImpl;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/**
* @author Luke Taylor
@ -93,6 +97,30 @@ public class HttpSessionRequestCacheTests {
assertThat(request.getSession().getAttribute("CUSTOM_SAVED_REQUEST")).isNotNull();
}
@Test
public void getMatchingRequestWhenMatchingRequestParameterNameSetThenSessionNotAccessed() {
HttpSessionRequestCache cache = new HttpSessionRequestCache();
cache.setMatchingRequestParameterName("success");
HttpServletRequest request = mock(HttpServletRequest.class);
HttpServletRequest matchingRequest = cache.getMatchingRequest(request, new MockHttpServletResponse());
assertThat(matchingRequest).isNull();
verify(request, never()).getSession();
verify(request, never()).getSession(anyBoolean());
}
@Test
public void getMatchingRequestWhenMatchingRequestParameterNameSetAndParameterExistThenLookedUp() {
MockHttpServletRequest request = new MockHttpServletRequest();
HttpSessionRequestCache cache = new HttpSessionRequestCache();
cache.setMatchingRequestParameterName("success");
cache.saveRequest(request, new MockHttpServletResponse());
MockHttpServletRequest requestToMatch = new MockHttpServletRequest();
requestToMatch.setParameter("success", "");
requestToMatch.setSession(request.getSession());
HttpServletRequest matchingRequest = cache.getMatchingRequest(requestToMatch, new MockHttpServletResponse());
assertThat(matchingRequest).isNotNull();
}
private static final class CustomSavedRequest implements SavedRequest {
private final SavedRequest delegate;

View File

@ -25,8 +25,13 @@ import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
/**
* @author Rob Winch
@ -96,4 +101,27 @@ public class WebSessionServerRequestCacheTests {
assertThat(this.cache.getRedirectUri(exchange).block()).isNull();
}
@Test
public void removeMatchingRequestWhenNoParameter() {
this.cache.setMatchingRequestParameterName("success");
MockServerHttpRequest request = MockServerHttpRequest.get("/secured/").build();
ServerWebExchange exchange = mock(ServerWebExchange.class);
given(exchange.getRequest()).willReturn(request);
assertThat(this.cache.removeMatchingRequest(exchange).block()).isNull();
verify(exchange, never()).getSession();
}
@Test
public void removeMatchingRequestWhenParameter() {
this.cache.setMatchingRequestParameterName("success");
MockServerHttpRequest request = MockServerHttpRequest.get("/secured/").accept(MediaType.TEXT_HTML).build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
this.cache.saveRequest(exchange).block();
String redirectUri = "/secured/?success";
assertThat(this.cache.getRedirectUri(exchange).block()).isEqualTo(URI.create(redirectUri));
MockServerHttpRequest redirectRequest = MockServerHttpRequest.get(redirectUri).build();
ServerWebExchange redirectExchange = exchange.mutate().request(redirectRequest).build();
assertThat(this.cache.removeMatchingRequest(redirectExchange).block()).isNotNull();
}
}