diff --git a/web/src/main/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewall.java b/web/src/main/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewall.java index e56c6160f7..0916f1ffd9 100644 --- a/web/src/main/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewall.java +++ b/web/src/main/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewall.java @@ -16,6 +16,8 @@ package org.springframework.security.web.server.firewall; +import java.net.InetSocketAddress; +import java.net.URI; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -23,6 +25,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Consumer; import java.util.function.Predicate; import java.util.regex.Pattern; @@ -33,6 +36,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequestDecorator; import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.http.server.reactive.SslInfo; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; @@ -743,6 +747,11 @@ public class StrictServerWebExchangeFirewall implements ServerWebExchangeFirewal return queryParams; } + @Override + public Builder mutate() { + return new StrictFirewallBuilder(super.mutate()); + } + private final class StrictFirewallHttpHeaders extends HttpHeaders { private StrictFirewallHttpHeaders(HttpHeaders delegate) { @@ -783,6 +792,61 @@ public class StrictServerWebExchangeFirewall implements ServerWebExchangeFirewal } + private final class StrictFirewallBuilder implements Builder { + + private final Builder delegate; + + private StrictFirewallBuilder(Builder delegate) { + this.delegate = delegate; + } + + @Override + public Builder method(HttpMethod httpMethod) { + return this.delegate.method(httpMethod); + } + + @Override + public Builder uri(URI uri) { + return this.delegate.uri(uri); + } + + @Override + public Builder path(String path) { + return this.delegate.path(path); + } + + @Override + public Builder contextPath(String contextPath) { + return this.delegate.contextPath(contextPath); + } + + @Override + public Builder header(String headerName, String... headerValues) { + return this.delegate.header(headerName, headerValues); + } + + @Override + public Builder headers(Consumer headersConsumer) { + return this.delegate.headers(headersConsumer); + } + + @Override + public Builder sslInfo(SslInfo sslInfo) { + return this.delegate.sslInfo(sslInfo); + } + + @Override + public Builder remoteAddress(InetSocketAddress remoteAddress) { + return this.delegate.remoteAddress(remoteAddress); + } + + @Override + public ServerHttpRequest build() { + return new StrictFirewallHttpRequest(this.delegate.build()); + } + + } + } } diff --git a/web/src/test/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewallTests.java b/web/src/test/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewallTests.java index b4eb293993..b8803bc0d1 100644 --- a/web/src/test/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewallTests.java @@ -513,4 +513,25 @@ class StrictServerWebExchangeFirewallTests { assertThat(exchange.getRequest().getHeaders().get(null)).isNull(); } + @Test + void getFirewalledExchangeWhenMutateThenHeadersStillFirewalled() { + String invalidHeaderName = "bad name"; + this.firewall.setAllowedHeaderNames((name) -> !name.equals(invalidHeaderName)); + ServerWebExchange exchange = getFirewalledExchange(); + ServerWebExchange mutatedExchange = exchange.mutate().request(exchange.getRequest().mutate().build()).build(); + HttpHeaders headers = mutatedExchange.getRequest().getHeaders(); + assertThatExceptionOfType(ServerExchangeRejectedException.class) + .isThrownBy(() -> headers.get(invalidHeaderName)); + } + + @Test + void getMutatedFirewalledExchangeGetHeaderWhenNotAllowedHeaderNameThenException() { + String invalidHeaderName = "bad name"; + this.firewall.setAllowedHeaderNames((name) -> !name.equals(invalidHeaderName)); + ServerWebExchange exchange = getFirewalledExchange(); + HttpHeaders headers = exchange.getRequest().mutate().build().getHeaders(); + assertThatExceptionOfType(ServerExchangeRejectedException.class) + .isThrownBy(() -> headers.get(invalidHeaderName)); + } + }