parent
8c063f8ccb
commit
b3e0f167ff
|
@ -17,7 +17,6 @@
|
||||||
package org.springframework.security.web.server.header;
|
package org.springframework.security.web.server.header;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
|
||||||
|
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
|
@ -41,8 +40,16 @@ public class StaticServerHttpHeadersWriter implements ServerHttpHeadersWriter {
|
||||||
@Override
|
@Override
|
||||||
public Mono<Void> writeHttpHeaders(ServerWebExchange exchange) {
|
public Mono<Void> writeHttpHeaders(ServerWebExchange exchange) {
|
||||||
HttpHeaders headers = exchange.getResponse().getHeaders();
|
HttpHeaders headers = exchange.getResponse().getHeaders();
|
||||||
boolean containsOneHeaderToAdd = Collections.disjoint(headers.keySet(), this.headersToAdd.keySet());
|
// Note: We need to ensure that the following algorithm compares headers
|
||||||
if (containsOneHeaderToAdd) {
|
// case insensitively, which should be true of headers.containsKey().
|
||||||
|
boolean containsNoHeadersToAdd = true;
|
||||||
|
for (String headerName : this.headersToAdd.keySet()) {
|
||||||
|
if (headers.containsKey(headerName)) {
|
||||||
|
containsNoHeadersToAdd = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (containsNoHeadersToAdd) {
|
||||||
this.headersToAdd.forEach(headers::put);
|
this.headersToAdd.forEach(headers::put);
|
||||||
}
|
}
|
||||||
return Mono.empty();
|
return Mono.empty();
|
||||||
|
|
|
@ -16,11 +16,14 @@
|
||||||
|
|
||||||
package org.springframework.security.web.server.header;
|
package org.springframework.security.web.server.header;
|
||||||
|
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
||||||
import org.springframework.mock.web.server.MockServerWebExchange;
|
import org.springframework.mock.web.server.MockServerWebExchange;
|
||||||
|
import org.springframework.util.LinkedMultiValueMap;
|
||||||
import org.springframework.web.server.ServerWebExchange;
|
import org.springframework.web.server.ServerWebExchange;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
@ -56,6 +59,24 @@ public class StaticServerHttpHeadersWriterTests {
|
||||||
.containsOnly(headerValue);
|
.containsOnly(headerValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gh-10557
|
||||||
|
@Test
|
||||||
|
public void writeHeadersWhenHeaderWrittenWithDifferentCaseThenDoesNotWriteHeaders() {
|
||||||
|
String headerName = HttpHeaders.CACHE_CONTROL.toLowerCase(Locale.ROOT);
|
||||||
|
String headerValue = "max-age=120";
|
||||||
|
this.headers.set(headerName, headerValue);
|
||||||
|
// Note: This test inverts which collection uses case sensitive headers,
|
||||||
|
// due to the fact that gh-10557 reports NettyHeadersAdapter as the
|
||||||
|
// response headers implementation, which is not accessible here.
|
||||||
|
HttpHeaders caseSensitiveHeaders = new HttpHeaders(new LinkedMultiValueMap<>());
|
||||||
|
caseSensitiveHeaders.set(HttpHeaders.CACHE_CONTROL, CacheControlServerHttpHeadersWriter.CACHE_CONTRTOL_VALUE);
|
||||||
|
caseSensitiveHeaders.set(HttpHeaders.PRAGMA, CacheControlServerHttpHeadersWriter.PRAGMA_VALUE);
|
||||||
|
caseSensitiveHeaders.set(HttpHeaders.EXPIRES, CacheControlServerHttpHeadersWriter.EXPIRES_VALUE);
|
||||||
|
this.writer = new StaticServerHttpHeadersWriter(caseSensitiveHeaders);
|
||||||
|
this.writer.writeHttpHeaders(this.exchange);
|
||||||
|
assertThat(this.headers.get(headerName)).containsOnly(headerValue);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void writeHeadersWhenMultiHeaderThenWritesAllHeaders() {
|
public void writeHeadersWhenMultiHeaderThenWritesAllHeaders() {
|
||||||
this.writer = StaticServerHttpHeadersWriter.builder()
|
this.writer = StaticServerHttpHeadersWriter.builder()
|
||||||
|
|
Loading…
Reference in New Issue