Merge pull request #3759 from rwinch/gh-2953

Cache Control only written if not set
This commit is contained in:
Rob Winch 2016-03-15 13:03:58 -05:00
commit 0f2a3b18ce
7 changed files with 407 additions and 152 deletions

View File

@ -17,10 +17,9 @@ package org.springframework.security.web.context;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.util.OnCommittedResponseWrapper;
/**
* Base class for response wrappers which encapsulate the logic for storing a security
@ -40,10 +39,8 @@ import org.springframework.security.core.context.SecurityContextHolder;
* @author Rob Winch
* @since 3.0
*/
public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
OnCommittedResponseWrapper {
private final Log logger = LogFactory.getLog(getClass());
public abstract class SaveContextOnUpdateOrErrorResponseWrapper
extends OnCommittedResponseWrapper {
private boolean contextSaved = false;
/* See SEC-1052 */
@ -86,12 +83,12 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
@Override
protected void onResponseCommitted() {
saveContext(SecurityContextHolder.getContext());
contextSaved = true;
this.contextSaved = true;
}
@Override
public final String encodeRedirectUrl(String url) {
if (disableUrlRewriting) {
if (this.disableUrlRewriting) {
return url;
}
return super.encodeRedirectUrl(url);
@ -99,7 +96,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
@Override
public final String encodeRedirectURL(String url) {
if (disableUrlRewriting) {
if (this.disableUrlRewriting) {
return url;
}
return super.encodeRedirectURL(url);
@ -107,7 +104,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
@Override
public final String encodeUrl(String url) {
if (disableUrlRewriting) {
if (this.disableUrlRewriting) {
return url;
}
return super.encodeUrl(url);
@ -115,7 +112,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
@Override
public final String encodeURL(String url) {
if (disableUrlRewriting) {
if (this.disableUrlRewriting) {
return url;
}
return super.encodeURL(url);
@ -126,6 +123,6 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
* wrapper.
*/
public final boolean isContextSaved() {
return contextSaved;
return this.contextSaved;
}
}

View File

@ -15,15 +15,17 @@
*/
package org.springframework.security.web.header;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.List;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;
import org.springframework.security.web.util.OnCommittedResponseWrapper;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
/**
* Filter implementation to add headers to the current request. Can be useful to add
@ -56,12 +58,52 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
throws ServletException, IOException {
for (HeaderWriter headerWriter : headerWriters) {
headerWriter.writeHeaders(request, response);
HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request,
response, this.headerWriters);
try {
filterChain.doFilter(request, headerWriterResponse);
}
finally {
headerWriterResponse.writeHeaders();
}
filterChain.doFilter(request, response);
}
static class HeaderWriterResponse extends OnCommittedResponseWrapper {
private final HttpServletRequest request;
private final List<HeaderWriter> headerWriters;
HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response,
List<HeaderWriter> headerWriters) {
super(response);
this.request = request;
this.headerWriters = headerWriters;
}
/*
* (non-Javadoc)
*
* @see org.springframework.security.web.util.OnCommittedResponseWrapper#
* onResponseCommitted()
*/
@Override
protected void onResponseCommitted() {
writeHeaders();
this.disableOnResponseCommitted();
}
protected void writeHeaders() {
if (isDisableOnResponseCommitted()) {
return;
}
for (HeaderWriter headerWriter : this.headerWriters) {
headerWriter.writeHeaders(this.request, getHttpResponse());
}
}
private HttpServletResponse getHttpResponse() {
return (HttpServletResponse) getResponse();
}
}
}

View File

@ -15,14 +15,20 @@
*/
package org.springframework.security.web.header.writers;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.web.header.Header;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.util.ReflectionUtils;
/**
* A {@link StaticHeadersWriter} that inserts headers to prevent caching. Specifically it
* adds the following headers:
* Inserts headers to prevent caching if no cache control headers have been specified.
* Specifically it adds the following headers:
* <ul>
* <li>Cache-Control: no-cache, no-store, max-age=0, must-revalidate</li>
* <li>Pragma: no-cache</li>
@ -32,21 +38,47 @@ import org.springframework.security.web.header.Header;
* @author Rob Winch
* @since 3.2
*/
public final class CacheControlHeadersWriter extends StaticHeadersWriter {
public final class CacheControlHeadersWriter implements HeaderWriter {
private static final String EXPIRES = "Expires";
private static final String PRAGMA = "Pragma";
private static final String CACHE_CONTROL = "Cache-Control";
private final Method getHeaderMethod;
private final HeaderWriter delegate;
/**
* Creates a new instance
*/
public CacheControlHeadersWriter() {
super(createHeaders());
this.delegate = new StaticHeadersWriter(createHeaders());
this.getHeaderMethod = ReflectionUtils.findMethod(HttpServletResponse.class,
"getHeader", String.class);
}
@Override
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (hasHeader(response, CACHE_CONTROL) || hasHeader(response, EXPIRES)
|| hasHeader(response, PRAGMA)) {
return;
}
this.delegate.writeHeaders(request, response);
}
private boolean hasHeader(HttpServletResponse response, String headerName) {
if (this.getHeaderMethod == null) {
return false;
}
return ReflectionUtils.invokeMethod(this.getHeaderMethod, response,
headerName) != null;
}
private static List<Header> createHeaders() {
List<Header> headers = new ArrayList<Header>(2);
headers.add(new Header("Cache-Control",
headers.add(new Header(CACHE_CONTROL,
"no-cache, no-store, max-age=0, must-revalidate"));
headers.add(new Header("Pragma", "no-cache"));
headers.add(new Header("Expires", "0"));
headers.add(new Header(PRAGMA, "no-cache"));
headers.add(new Header(EXPIRES, "0"));
return headers;
}
}

View File

@ -13,33 +13,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.context;
package org.springframework.security.web.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Locale;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
/**
* Base class for response wrappers which encapsulate the logic for handling an event when the
* {@link javax.servlet.http.HttpServletResponse} is committed.
* Base class for response wrappers which encapsulate the logic for handling an event when
* the {@link javax.servlet.http.HttpServletResponse} is committed.
*
* @since 4.0.2
* @author Rob Winch
*/
abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
private final Log logger = LogFactory.getLog(getClass());
public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
private boolean disableOnCommitted;
/**
* The Content-Length response header. If this is greater than 0, then once {@link #contentWritten} is larger than
* or equal the response is considered committed.
* The Content-Length response header. If this is greater than 0, then once
* {@link #contentWritten} is larger than or equal the response is considered
* committed.
*/
private long contentLength;
@ -57,7 +55,7 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
@Override
public void addHeader(String name, String value) {
if("Content-Length".equalsIgnoreCase(name)) {
if ("Content-Length".equalsIgnoreCase(name)) {
setContentLength(Long.parseLong(value));
}
super.addHeader(name, value);
@ -75,22 +73,33 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Invoke this method to disable invoking {@link OnCommittedResponseWrapper#onResponseCommitted()} when the {@link javax.servlet.http.HttpServletResponse} is
* committed. This can be useful in the event that Async Web Requests are
* made.
* Invoke this method to disable invoking
* {@link OnCommittedResponseWrapper#onResponseCommitted()} when the
* {@link javax.servlet.http.HttpServletResponse} is committed. This can be useful in
* the event that Async Web Requests are made.
*/
public void disableOnResponseCommitted() {
protected void disableOnResponseCommitted() {
this.disableOnCommitted = true;
}
/**
* Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} being committed
* Returns true if {@link #onResponseCommitted()} will be invoked when the response is
* committed, else false.
* @return if {@link #onResponseCommitted()} is enabled
*/
protected boolean isDisableOnResponseCommitted() {
return this.disableOnCommitted;
}
/**
* Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse}
* being committed
*/
protected abstract void onResponseCommitted();
/**
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
* superclass <code>sendError()</code>
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
* before calling the superclass <code>sendError()</code>
*/
@Override
public final void sendError(int sc) throws IOException {
@ -99,8 +108,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
* superclass <code>sendError()</code>
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
* before calling the superclass <code>sendError()</code>
*/
@Override
public final void sendError(int sc, String msg) throws IOException {
@ -109,8 +118,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
* superclass <code>sendRedirect()</code>
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
* before calling the superclass <code>sendRedirect()</code>
*/
@Override
public final void sendRedirect(String location) throws IOException {
@ -119,8 +128,9 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the calling
* <code>getOutputStream().close()</code> or <code>getOutputStream().flush()</code>
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
* before calling the calling <code>getOutputStream().close()</code> or
* <code>getOutputStream().flush()</code>
*/
@Override
public ServletOutputStream getOutputStream() throws IOException {
@ -128,8 +138,9 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
* <code>getWriter().close()</code> or <code>getWriter().flush()</code>
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
* before calling the <code>getWriter().close()</code> or
* <code>getWriter().flush()</code>
*/
@Override
public PrintWriter getWriter() throws IOException {
@ -137,8 +148,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
* superclass <code>flushBuffer()</code>
* Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
* before calling the superclass <code>flushBuffer()</code>
*/
@Override
public void flushBuffer() throws IOException {
@ -187,36 +198,38 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
}
/**
* Adds the contentLengthToWrite to the total contentWritten size and checks to see if the response should be
* written.
* Adds the contentLengthToWrite to the total contentWritten size and checks to see if
* the response should be written.
*
* @param contentLengthToWrite the size of the content that is about to be written.
*/
private void checkContentLength(long contentLengthToWrite) {
contentWritten += contentLengthToWrite;
boolean isBodyFullyWritten = contentLength > 0 && contentWritten >= contentLength;
this.contentWritten += contentLengthToWrite;
boolean isBodyFullyWritten = this.contentLength > 0
&& this.contentWritten >= this.contentLength;
int bufferSize = getBufferSize();
boolean requiresFlush = bufferSize > 0 && contentWritten >= bufferSize;
if(isBodyFullyWritten || requiresFlush) {
boolean requiresFlush = bufferSize > 0 && this.contentWritten >= bufferSize;
if (isBodyFullyWritten || requiresFlush) {
doOnResponseCommitted();
}
}
/**
* Calls <code>onResponseCommmitted()</code> with the current contents as long as
* {@link #disableOnResponseCommitted()()} was not invoked.
* {@link #disableOnResponseCommitted()} was not invoked.
*/
private void doOnResponseCommitted() {
if(!disableOnCommitted) {
if (!this.disableOnCommitted) {
onResponseCommitted();
disableOnResponseCommitted();
}
}
/**
* Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the prior to methods that commit the response. We delegate all methods
* to the original {@link java.io.PrintWriter} to ensure that the behavior is as close to the original {@link java.io.PrintWriter}
* as possible. See SEC-2039
* Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before
* calling the prior to methods that commit the response. We delegate all methods to
* the original {@link java.io.PrintWriter} to ensure that the behavior is as close to
* the original {@link java.io.PrintWriter} as possible. See SEC-2039
* @author Rob Winch
*/
private class SaveContextPrintWriter extends PrintWriter {
@ -227,197 +240,235 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
this.delegate = delegate;
}
@Override
public void flush() {
doOnResponseCommitted();
delegate.flush();
this.delegate.flush();
}
@Override
public void close() {
doOnResponseCommitted();
delegate.close();
this.delegate.close();
}
@Override
public int hashCode() {
return delegate.hashCode();
return this.delegate.hashCode();
}
@Override
public boolean equals(Object obj) {
return delegate.equals(obj);
return this.delegate.equals(obj);
}
@Override
public String toString() {
return getClass().getName() + "[delegate=" + delegate.toString() + "]";
return getClass().getName() + "[delegate=" + this.delegate.toString() + "]";
}
@Override
public boolean checkError() {
return delegate.checkError();
return this.delegate.checkError();
}
@Override
public void write(int c) {
trackContentLength(c);
delegate.write(c);
this.delegate.write(c);
}
@Override
public void write(char[] buf, int off, int len) {
checkContentLength(len);
delegate.write(buf, off, len);
this.delegate.write(buf, off, len);
}
@Override
public void write(char[] buf) {
trackContentLength(buf);
delegate.write(buf);
this.delegate.write(buf);
}
@Override
public void write(String s, int off, int len) {
checkContentLength(len);
delegate.write(s, off, len);
this.delegate.write(s, off, len);
}
@Override
public void write(String s) {
trackContentLength(s);
delegate.write(s);
this.delegate.write(s);
}
@Override
public void print(boolean b) {
trackContentLength(b);
delegate.print(b);
this.delegate.print(b);
}
@Override
public void print(char c) {
trackContentLength(c);
delegate.print(c);
this.delegate.print(c);
}
@Override
public void print(int i) {
trackContentLength(i);
delegate.print(i);
this.delegate.print(i);
}
@Override
public void print(long l) {
trackContentLength(l);
delegate.print(l);
this.delegate.print(l);
}
@Override
public void print(float f) {
trackContentLength(f);
delegate.print(f);
this.delegate.print(f);
}
@Override
public void print(double d) {
trackContentLength(d);
delegate.print(d);
this.delegate.print(d);
}
@Override
public void print(char[] s) {
trackContentLength(s);
delegate.print(s);
this.delegate.print(s);
}
@Override
public void print(String s) {
trackContentLength(s);
delegate.print(s);
this.delegate.print(s);
}
@Override
public void print(Object obj) {
trackContentLength(obj);
delegate.print(obj);
this.delegate.print(obj);
}
@Override
public void println() {
trackContentLengthLn();
delegate.println();
this.delegate.println();
}
@Override
public void println(boolean x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(char x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(int x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(long x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(float x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(double x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(char[] x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(String x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public void println(Object x) {
trackContentLength(x);
trackContentLengthLn();
delegate.println(x);
this.delegate.println(x);
}
@Override
public PrintWriter printf(String format, Object... args) {
return delegate.printf(format, args);
return this.delegate.printf(format, args);
}
@Override
public PrintWriter printf(Locale l, String format, Object... args) {
return delegate.printf(l, format, args);
return this.delegate.printf(l, format, args);
}
@Override
public PrintWriter format(String format, Object... args) {
return delegate.format(format, args);
return this.delegate.format(format, args);
}
@Override
public PrintWriter format(Locale l, String format, Object... args) {
return delegate.format(l, format, args);
return this.delegate.format(l, format, args);
}
@Override
public PrintWriter append(CharSequence csq) {
checkContentLength(csq.length());
return delegate.append(csq);
return this.delegate.append(csq);
}
@Override
public PrintWriter append(CharSequence csq, int start, int end) {
checkContentLength(end - start);
return delegate.append(csq, start, end);
return this.delegate.append(csq, start, end);
}
@Override
public PrintWriter append(char c) {
trackContentLength(c);
return delegate.append(c);
return this.delegate.append(c);
}
}
/**
* Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling methods that commit the response. We delegate all methods
* to the original {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close to the original {@link javax.servlet.ServletOutputStream}
* as possible. See SEC-2039
* Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before
* calling methods that commit the response. We delegate all methods to the original
* {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close
* to the original {@link javax.servlet.ServletOutputStream} as possible. See SEC-2039
*
* @author Rob Winch
*/
@ -428,123 +479,146 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
this.delegate = delegate;
}
@Override
public void write(int b) throws IOException {
trackContentLength(b);
this.delegate.write(b);
}
@Override
public void flush() throws IOException {
doOnResponseCommitted();
delegate.flush();
this.delegate.flush();
}
@Override
public void close() throws IOException {
doOnResponseCommitted();
delegate.close();
this.delegate.close();
}
@Override
public int hashCode() {
return delegate.hashCode();
return this.delegate.hashCode();
}
@Override
public boolean equals(Object obj) {
return delegate.equals(obj);
return this.delegate.equals(obj);
}
@Override
public void print(boolean b) throws IOException {
trackContentLength(b);
delegate.print(b);
this.delegate.print(b);
}
@Override
public void print(char c) throws IOException {
trackContentLength(c);
delegate.print(c);
this.delegate.print(c);
}
@Override
public void print(double d) throws IOException {
trackContentLength(d);
delegate.print(d);
this.delegate.print(d);
}
@Override
public void print(float f) throws IOException {
trackContentLength(f);
delegate.print(f);
this.delegate.print(f);
}
@Override
public void print(int i) throws IOException {
trackContentLength(i);
delegate.print(i);
this.delegate.print(i);
}
@Override
public void print(long l) throws IOException {
trackContentLength(l);
delegate.print(l);
this.delegate.print(l);
}
@Override
public void print(String s) throws IOException {
trackContentLength(s);
delegate.print(s);
this.delegate.print(s);
}
@Override
public void println() throws IOException {
trackContentLengthLn();
delegate.println();
this.delegate.println();
}
@Override
public void println(boolean b) throws IOException {
trackContentLength(b);
trackContentLengthLn();
delegate.println(b);
this.delegate.println(b);
}
@Override
public void println(char c) throws IOException {
trackContentLength(c);
trackContentLengthLn();
delegate.println(c);
this.delegate.println(c);
}
@Override
public void println(double d) throws IOException {
trackContentLength(d);
trackContentLengthLn();
delegate.println(d);
this.delegate.println(d);
}
@Override
public void println(float f) throws IOException {
trackContentLength(f);
trackContentLengthLn();
delegate.println(f);
this.delegate.println(f);
}
@Override
public void println(int i) throws IOException {
trackContentLength(i);
trackContentLengthLn();
delegate.println(i);
this.delegate.println(i);
}
@Override
public void println(long l) throws IOException {
trackContentLength(l);
trackContentLengthLn();
delegate.println(l);
this.delegate.println(l);
}
@Override
public void println(String s) throws IOException {
trackContentLength(s);
trackContentLengthLn();
delegate.println(s);
this.delegate.println(s);
}
@Override
public void write(byte[] b) throws IOException {
trackContentLength(b);
delegate.write(b);
this.delegate.write(b);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
checkContentLength(len);
delegate.write(b, off, len);
this.delegate.write(b, off, len);
}
@Override
public String toString() {
return getClass().getName() + "[delegate=" + delegate.toString() + "]";
return getClass().getName() + "[delegate=" + this.delegate.toString() + "]";
}
}
}

View File

@ -15,21 +15,32 @@
*/
package org.springframework.security.web.header;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verify;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.header.HeaderWriter;
import org.springframework.security.web.header.HeaderWriterFilter;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
/**
* Tests for the {@code HeadersFilter}
@ -60,8 +71,8 @@ public class HeaderWriterFilterTests {
@Test
public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
headerWriters.add(writer1);
headerWriters.add(writer2);
headerWriters.add(this.writer1);
headerWriters.add(this.writer2);
HeaderWriterFilter filter = new HeaderWriterFilter(headerWriters);
@ -71,9 +82,34 @@ public class HeaderWriterFilterTests {
filter.doFilter(request, response, filterChain);
verify(writer1).writeHeaders(request, response);
verify(writer2).writeHeaders(request, response);
verify(this.writer1).writeHeaders(request, response);
verify(this.writer2).writeHeaders(request, response);
assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain
// continued
}
// gh-2953
@Test
public void headersDelayed() throws Exception {
HeaderWriterFilter filter = new HeaderWriterFilter(
Arrays.<HeaderWriter>asList(this.writer1));
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
filter.doFilter(request, response, new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
verifyZeroInteractions(HeaderWriterFilterTests.this.writer1);
response.flushBuffer();
verify(HeaderWriterFilterTests.this.writer1).writeHeaders(
any(HttpServletRequest.class), any(HttpServletResponse.class));
}
});
verifyNoMoreInteractions(this.writer1);
}
}

View File

@ -15,19 +15,32 @@
*/
package org.springframework.security.web.header.writers;
import static org.assertj.core.api.Assertions.assertThat;
import java.util.Arrays;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when;
import static org.powermock.api.mockito.PowerMockito.spy;
/**
* @author Rob Winch
*
*/
@RunWith(PowerMockRunner.class)
@PrepareOnlyThisForTest(ReflectionUtils.class)
public class CacheControlHeadersWriterTests {
private MockHttpServletRequest request;
@ -38,20 +51,79 @@ public class CacheControlHeadersWriterTests {
@Before
public void setup() {
request = new MockHttpServletRequest();
response = new MockHttpServletResponse();
writer = new CacheControlHeadersWriter();
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.writer = new CacheControlHeadersWriter();
}
@Test
public void writeHeaders() {
writer.writeHeaders(request, response);
this.writer.writeHeaders(this.request, this.response);
assertThat(response.getHeaderNames().size()).isEqualTo(3);
assertThat(response.getHeaderValues("Cache-Control")).isEqualTo(
assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
assertThat(response.getHeaderValues("Pragma")).isEqualTo(
Arrays.asList("no-cache"));
assertThat(response.getHeaderValues("Expires")).isEqualTo(Arrays.asList("0"));
assertThat(this.response.getHeaderValues("Pragma"))
.isEqualTo(Arrays.asList("no-cache"));
assertThat(this.response.getHeaderValues("Expires"))
.isEqualTo(Arrays.asList("0"));
}
@Test
public void writeHeadersServlet25() {
spy(ReflectionUtils.class);
when(ReflectionUtils.findMethod(HttpServletResponse.class, "getHeader",
String.class)).thenReturn(null);
this.response = spy(this.response);
doThrow(NoSuchMethodError.class).when(this.response).getHeader(anyString());
this.writer = new CacheControlHeadersWriter();
this.writer.writeHeaders(this.request, this.response);
assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
assertThat(this.response.getHeaderValues("Pragma"))
.isEqualTo(Arrays.asList("no-cache"));
assertThat(this.response.getHeaderValues("Expires"))
.isEqualTo(Arrays.asList("0"));
}
// gh-2953
@Test
public void writeHeadersDisabledIfCacheControl() {
this.response.setHeader("Cache-Control", "max-age: 123");
this.writer.writeHeaders(this.request, this.response);
assertThat(this.response.getHeaderNames()).hasSize(1);
assertThat(this.response.getHeaderValues("Cache-Control"))
.containsOnly("max-age: 123");
assertThat(this.response.getHeaderValue("Pragma")).isNull();
assertThat(this.response.getHeaderValue("Expires")).isNull();
}
@Test
public void writeHeadersDisabledIfPragma() {
this.response.setHeader("Pragma", "mock");
this.writer.writeHeaders(this.request, this.response);
assertThat(this.response.getHeaderNames()).hasSize(1);
assertThat(this.response.getHeaderValues("Pragma")).containsOnly("mock");
assertThat(this.response.getHeaderValue("Expires")).isNull();
assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
}
@Test
public void writeHeadersDisabledIfExpires() {
this.response.setHeader("Expires", "mock");
this.writer.writeHeaders(this.request, this.response);
assertThat(this.response.getHeaderNames()).hasSize(1);
assertThat(this.response.getHeaderValues("Expires")).containsOnly("mock");
assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
assertThat(this.response.getHeaderValue("Pragma")).isNull();
}
}

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.context;
package org.springframework.security.web.util;
import java.io.IOException;
import java.io.PrintWriter;
@ -25,6 +25,8 @@ import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.security.web.util.OnCommittedResponseWrapper;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;