From f81b58112b372c48b94e9b0260ecee0853d94243 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 7 Feb 2018 14:43:27 -0600 Subject: [PATCH] Cache headers only if no cache headers set Fixes: gh-5005 --- .../web/HttpSecurityHeadersTests.java | 4 +- .../web/header/HeaderWriterFilter.java | 58 ++++++++++++++++--- .../web/header/HeaderWriterFilterTests.java | 50 +++++++++++++--- .../CacheControlHeadersWriterTests.java | 22 ++++--- 4 files changed, 108 insertions(+), 26 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java index 79e9678188..c25590ea9a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java @@ -71,8 +71,8 @@ public class HttpSecurityHeadersTests { mockMvc.perform(get("/resources/file.js")) .andExpect(status().isOk()) .andExpect(header().string(HttpHeaders.CACHE_CONTROL,"max-age=12345")) - .andExpect(header().string(HttpHeaders.PRAGMA, "")) - .andExpect(header().string(HttpHeaders.EXPIRES, "")); + .andExpect(header().doesNotExist(HttpHeaders.PRAGMA)) + .andExpect(header().doesNotExist(HttpHeaders.EXPIRES)); } @Test diff --git a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java index 1fbd0c256e..d963c79748 100644 --- a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java +++ b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java @@ -15,15 +15,17 @@ */ package org.springframework.security.web.header; -import org.springframework.util.Assert; -import org.springframework.web.filter.OncePerRequestFilter; +import java.io.IOException; +import java.util.List; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.util.*; + +import org.springframework.security.web.util.OnCommittedResponseWrapper; +import org.springframework.util.Assert; +import org.springframework.web.filter.OncePerRequestFilter; /** * Filter implementation to add headers to the current response. Can be useful to add @@ -56,12 +58,52 @@ public class HeaderWriterFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + throws ServletException, IOException { - for (HeaderWriter headerWriter : headerWriters) { - headerWriter.writeHeaders(request, response); + HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request, + response, this.headerWriters); + try { + filterChain.doFilter(request, headerWriterResponse); + } + finally { + headerWriterResponse.writeHeaders(); } - filterChain.doFilter(request, response); } + static class HeaderWriterResponse extends OnCommittedResponseWrapper { + private final HttpServletRequest request; + private final List headerWriters; + + HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response, + List headerWriters) { + super(response); + this.request = request; + this.headerWriters = headerWriters; + } + + /* + * (non-Javadoc) + * + * @see org.springframework.security.web.util.OnCommittedResponseWrapper# + * onResponseCommitted() + */ + @Override + protected void onResponseCommitted() { + writeHeaders(); + this.disableOnResponseCommitted(); + } + + protected void writeHeaders() { + if (isDisableOnResponseCommitted()) { + return; + } + for (HeaderWriter headerWriter : this.headerWriters) { + headerWriter.writeHeaders(this.request, getHttpResponse()); + } + } + + private HttpServletResponse getHttpResponse() { + return (HttpServletResponse) getResponse(); + } + } } diff --git a/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java b/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java index a6e9c80a54..32d7ead22e 100644 --- a/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java @@ -15,21 +15,32 @@ */ package org.springframework.security.web.header; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.verify; - +import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; + import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.web.header.HeaderWriter; -import org.springframework.security.web.header.HeaderWriterFilter; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests for the {@code HeadersFilter} @@ -71,9 +82,34 @@ public class HeaderWriterFilterTests { filter.doFilter(request, response, filterChain); - verify(writer1).writeHeaders(request, response); - verify(writer2).writeHeaders(request, response); + verify(this.writer1).writeHeaders(request, response); + verify(this.writer2).writeHeaders(request, response); assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain // continued } + + // gh-2953 + @Test + public void headersDelayed() throws Exception { + HeaderWriterFilter filter = new HeaderWriterFilter( + Arrays.asList(this.writer1)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, new FilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + verifyZeroInteractions(HeaderWriterFilterTests.this.writer1); + + response.flushBuffer(); + + verify(HeaderWriterFilterTests.this.writer1).writeHeaders( + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + }); + + verifyNoMoreInteractions(this.writer1); + } } diff --git a/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java b/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java index 3b6ed7cac5..ee1dc1ccf9 100644 --- a/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java +++ b/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java @@ -60,11 +60,13 @@ public class CacheControlHeadersWriterTests { public void writeHeaders() { this.writer.writeHeaders(this.request, this.response); - assertThat(this.response.getHeaderNames()).hasSize(3); - assertThat(this.response.getHeaderValues("Cache-Control")).containsExactly( + assertThat(this.response.getHeaderNames().size()).isEqualTo(3); + assertThat(this.response.getHeaderValues("Cache-Control")).containsOnly( "no-cache, no-store, max-age=0, must-revalidate"); - assertThat(this.response.getHeaderValues("Pragma")).containsOnly("no-cache"); - assertThat(this.response.getHeaderValues("Expires")).containsOnly("0"); + assertThat(this.response.getHeaderValues("Pragma")) + .containsOnly("no-cache"); + assertThat(this.response.getHeaderValues("Expires")) + .containsOnly("0"); } @Test @@ -78,11 +80,13 @@ public class CacheControlHeadersWriterTests { this.writer.writeHeaders(this.request, this.response); - assertThat(this.response.getHeaderNames()).hasSize(3); - assertThat(this.response.getHeaderValues("Cache-Control")).containsExactly( - "no-cache, no-store, max-age=0, must-revalidate"); - assertThat(this.response.getHeaderValues("Pragma")).containsOnly("no-cache"); - assertThat(this.response.getHeaderValues("Expires")).containsOnly("0"); + assertThat(this.response.getHeaderNames().size()).isEqualTo(3); + assertThat(this.response.getHeaderValues("Cache-Control")) + .containsOnly("no-cache, no-store, max-age=0, must-revalidate"); + assertThat(this.response.getHeaderValues("Pragma")) + .containsOnly("no-cache"); + assertThat(this.response.getHeaderValues("Expires")) + .containsOnly("0"); } // gh-2953