Cache headers only if no cache headers set

Fixes: gh-5005
This commit is contained in:
Rob Winch 2018-02-07 14:43:27 -06:00
parent 0e02b489c8
commit f81b58112b
4 changed files with 108 additions and 26 deletions

View File

@ -71,8 +71,8 @@ public class HttpSecurityHeadersTests {
mockMvc.perform(get("/resources/file.js")) mockMvc.perform(get("/resources/file.js"))
.andExpect(status().isOk()) .andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL,"max-age=12345")) .andExpect(header().string(HttpHeaders.CACHE_CONTROL,"max-age=12345"))
.andExpect(header().string(HttpHeaders.PRAGMA, "")) .andExpect(header().doesNotExist(HttpHeaders.PRAGMA))
.andExpect(header().string(HttpHeaders.EXPIRES, "")); .andExpect(header().doesNotExist(HttpHeaders.EXPIRES));
} }
@Test @Test

View File

@ -15,15 +15,17 @@
*/ */
package org.springframework.security.web.header; package org.springframework.security.web.header;
import org.springframework.util.Assert; import java.io.IOException;
import org.springframework.web.filter.OncePerRequestFilter; import java.util.List;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*; import org.springframework.security.web.util.OnCommittedResponseWrapper;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
/** /**
* Filter implementation to add headers to the current response. Can be useful to add * Filter implementation to add headers to the current response. Can be useful to add
@ -56,12 +58,52 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain) HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
for (HeaderWriter headerWriter : headerWriters) { HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request,
headerWriter.writeHeaders(request, response); response, this.headerWriters);
try {
filterChain.doFilter(request, headerWriterResponse);
}
finally {
headerWriterResponse.writeHeaders();
} }
filterChain.doFilter(request, response);
} }
static class HeaderWriterResponse extends OnCommittedResponseWrapper {
private final HttpServletRequest request;
private final List<HeaderWriter> headerWriters;
HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response,
List<HeaderWriter> headerWriters) {
super(response);
this.request = request;
this.headerWriters = headerWriters;
}
/*
* (non-Javadoc)
*
* @see org.springframework.security.web.util.OnCommittedResponseWrapper#
* onResponseCommitted()
*/
@Override
protected void onResponseCommitted() {
writeHeaders();
this.disableOnResponseCommitted();
}
protected void writeHeaders() {
if (isDisableOnResponseCommitted()) {
return;
}
for (HeaderWriter headerWriter : this.headerWriters) {
headerWriter.writeHeaders(this.request, getHttpResponse());
}
}
private HttpServletResponse getHttpResponse() {
return (HttpServletResponse) getResponse();
}
}
} }

View File

@ -15,21 +15,32 @@
*/ */
package org.springframework.security.web.header; package org.springframework.security.web.header;
import static org.assertj.core.api.Assertions.assertThat; import java.io.IOException;
import static org.mockito.Mockito.verify;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.header.HeaderWriterFilter; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
/** /**
* Tests for the {@code HeadersFilter} * Tests for the {@code HeadersFilter}
@ -71,9 +82,34 @@ public class HeaderWriterFilterTests {
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
verify(writer1).writeHeaders(request, response); verify(this.writer1).writeHeaders(request, response);
verify(writer2).writeHeaders(request, response); verify(this.writer2).writeHeaders(request, response);
assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain
// continued // continued
} }
// gh-2953
@Test
public void headersDelayed() throws Exception {
HeaderWriterFilter filter = new HeaderWriterFilter(
Arrays.<HeaderWriter>asList(this.writer1));
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
verifyZeroInteractions(HeaderWriterFilterTests.this.writer1);
response.flushBuffer();
verify(HeaderWriterFilterTests.this.writer1).writeHeaders(
any(HttpServletRequest.class), any(HttpServletResponse.class));
}
});
verifyNoMoreInteractions(this.writer1);
}
} }

View File

@ -60,11 +60,13 @@ public class CacheControlHeadersWriterTests {
public void writeHeaders() { public void writeHeaders() {
this.writer.writeHeaders(this.request, this.response); this.writer.writeHeaders(this.request, this.response);
assertThat(this.response.getHeaderNames()).hasSize(3); assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
assertThat(this.response.getHeaderValues("Cache-Control")).containsExactly( assertThat(this.response.getHeaderValues("Cache-Control")).containsOnly(
"no-cache, no-store, max-age=0, must-revalidate"); "no-cache, no-store, max-age=0, must-revalidate");
assertThat(this.response.getHeaderValues("Pragma")).containsOnly("no-cache"); assertThat(this.response.getHeaderValues("Pragma"))
assertThat(this.response.getHeaderValues("Expires")).containsOnly("0"); .containsOnly("no-cache");
assertThat(this.response.getHeaderValues("Expires"))
.containsOnly("0");
} }
@Test @Test
@ -78,11 +80,13 @@ public class CacheControlHeadersWriterTests {
this.writer.writeHeaders(this.request, this.response); this.writer.writeHeaders(this.request, this.response);
assertThat(this.response.getHeaderNames()).hasSize(3); assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
assertThat(this.response.getHeaderValues("Cache-Control")).containsExactly( assertThat(this.response.getHeaderValues("Cache-Control"))
"no-cache, no-store, max-age=0, must-revalidate"); .containsOnly("no-cache, no-store, max-age=0, must-revalidate");
assertThat(this.response.getHeaderValues("Pragma")).containsOnly("no-cache"); assertThat(this.response.getHeaderValues("Pragma"))
assertThat(this.response.getHeaderValues("Expires")).containsOnly("0"); .containsOnly("no-cache");
assertThat(this.response.getHeaderValues("Expires"))
.containsOnly("0");
} }
// gh-2953 // gh-2953