From fcc9a34356817d93c24b5ccf3107ec234a28b136 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 14 Jul 2015 14:41:50 -0500 Subject: [PATCH] SEC-2973: Add OnCommittedResponseWrapper This ensures that Spring Session & Security's logic for performing a save on the response being committed can easily be kept in synch. Further this ensures that the SecurityContext is now persisted when the response body meets the content length. --- .../context/OnCommittedResponseWrapper.java | 549 ++++++++ ...ContextOnUpdateOrErrorResponseWrapper.java | 392 +----- .../OnCommittedResponseWrapperTests.java | 1122 +++++++++++++++++ 3 files changed, 1699 insertions(+), 364 deletions(-) create mode 100644 web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java create mode 100644 web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java diff --git a/web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java b/web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java new file mode 100644 index 0000000000..0b388baf38 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java @@ -0,0 +1,549 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.springframework.security.web.context; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Locale; + +/** + * Base class for response wrappers which encapsulate the logic for handling an event when the + * {@link javax.servlet.http.HttpServletResponse} is committed. + * + * @since 4.0.2 + * @author Rob Winch + */ +abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { + private final Log logger = LogFactory.getLog(getClass()); + + private boolean disableOnCommitted; + + /** + * The Content-Length response header. If this is greater than 0, then once {@link #contentWritten} is larger than + * or equal the response is considered committed. + */ + private long contentLength; + + /** + * The size of data written to the response body. + */ + private long contentWritten; + + /** + * @param response the response to be wrapped + */ + public OnCommittedResponseWrapper(HttpServletResponse response) { + super(response); + } + + @Override + public void addHeader(String name, String value) { + if("Content-Length".equalsIgnoreCase(name)) { + setContentLength(Long.parseLong(value)); + } + super.addHeader(name, value); + } + + @Override + public void setContentLength(int len) { + setContentLength((long) len); + super.setContentLength(len); + } + + private void setContentLength(long len) { + this.contentLength = len; + checkContentLength(0); + } + + /** + * Invoke this method to disable invoking {@link OnCommittedResponseWrapper#onResponseCommitted()} when the {@link javax.servlet.http.HttpServletResponse} is + * committed. This can be useful in the event that Async Web Requests are + * made. + */ + public void disableOnResponseCommitted() { + this.disableOnCommitted = true; + } + + /** + * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} being committed + */ + protected abstract void onResponseCommitted(); + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass sendError() + */ + @Override + public final void sendError(int sc) throws IOException { + doOnResponseCommitted(); + super.sendError(sc); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass sendError() + */ + @Override + public final void sendError(int sc, String msg) throws IOException { + doOnResponseCommitted(); + super.sendError(sc, msg); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass sendRedirect() + */ + @Override + public final void sendRedirect(String location) throws IOException { + doOnResponseCommitted(); + super.sendRedirect(location); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the calling + * getOutputStream().close() or getOutputStream().flush() + */ + @Override + public ServletOutputStream getOutputStream() throws IOException { + return new SaveContextServletOutputStream(super.getOutputStream()); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * getWriter().close() or getWriter().flush() + */ + @Override + public PrintWriter getWriter() throws IOException { + return new SaveContextPrintWriter(super.getWriter()); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass flushBuffer() + */ + @Override + public void flushBuffer() throws IOException { + doOnResponseCommitted(); + super.flushBuffer(); + } + + private void trackContentLength(boolean content) { + checkContentLength(content ? 4 : 5); // TODO Localization + } + + private void trackContentLength(char content) { + checkContentLength(1); + } + + private void trackContentLength(Object content) { + trackContentLength(String.valueOf(content)); + } + + private void trackContentLength(byte[] content) { + checkContentLength(content == null ? 0 : content.length); + } + + private void trackContentLength(char[] content) { + checkContentLength(content == null ? 0 : content.length); + } + + private void trackContentLength(int content) { + trackContentLength(String.valueOf(content)); + } + + private void trackContentLength(float content) { + trackContentLength(String.valueOf(content)); + } + + private void trackContentLength(double content) { + trackContentLength(String.valueOf(content)); + } + + private void trackContentLengthLn() { + trackContentLength("\r\n"); + } + + private void trackContentLength(String content) { + checkContentLength(content.length()); + } + + /** + * Adds the contentLengthToWrite to the total contentWritten size and checks to see if the response should be + * written. + * + * @param contentLengthToWrite the size of the content that is about to be written. + */ + private void checkContentLength(long contentLengthToWrite) { + contentWritten += contentLengthToWrite; + boolean isBodyFullyWritten = contentLength > 0 && contentWritten >= contentLength; + int bufferSize = getBufferSize(); + boolean requiresFlush = bufferSize > 0 && contentWritten >= bufferSize; + if(isBodyFullyWritten || requiresFlush) { + doOnResponseCommitted(); + } + } + + /** + * Calls onResponseCommmitted() with the current contents as long as + * {@link #disableOnResponseCommitted()()} was not invoked. + */ + private void doOnResponseCommitted() { + if(!disableOnCommitted) { + onResponseCommitted(); + disableOnResponseCommitted(); + } else if(logger.isDebugEnabled()){ + logger.debug("Skip invoking on"); + } + } + + /** + * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the prior to methods that commit the response. We delegate all methods + * to the original {@link java.io.PrintWriter} to ensure that the behavior is as close to the original {@link java.io.PrintWriter} + * as possible. See SEC-2039 + * @author Rob Winch + */ + private class SaveContextPrintWriter extends PrintWriter { + private final PrintWriter delegate; + + public SaveContextPrintWriter(PrintWriter delegate) { + super(delegate); + this.delegate = delegate; + } + + public void flush() { + doOnResponseCommitted(); + delegate.flush(); + } + + public void close() { + doOnResponseCommitted(); + delegate.close(); + } + + public int hashCode() { + return delegate.hashCode(); + } + + public boolean equals(Object obj) { + return delegate.equals(obj); + } + + public String toString() { + return getClass().getName() + "[delegate=" + delegate.toString() + "]"; + } + + public boolean checkError() { + return delegate.checkError(); + } + + public void write(int c) { + trackContentLength(c); + delegate.write(c); + } + + public void write(char[] buf, int off, int len) { + checkContentLength(len); + delegate.write(buf, off, len); + } + + public void write(char[] buf) { + trackContentLength(buf); + delegate.write(buf); + } + + public void write(String s, int off, int len) { + checkContentLength(len); + delegate.write(s, off, len); + } + + public void write(String s) { + trackContentLength(s); + delegate.write(s); + } + + public void print(boolean b) { + trackContentLength(b); + delegate.print(b); + } + + public void print(char c) { + trackContentLength(c); + delegate.print(c); + } + + public void print(int i) { + trackContentLength(i); + delegate.print(i); + } + + public void print(long l) { + trackContentLength(l); + delegate.print(l); + } + + public void print(float f) { + trackContentLength(f); + delegate.print(f); + } + + public void print(double d) { + trackContentLength(d); + delegate.print(d); + } + + public void print(char[] s) { + trackContentLength(s); + delegate.print(s); + } + + public void print(String s) { + trackContentLength(s); + delegate.print(s); + } + + public void print(Object obj) { + trackContentLength(obj); + delegate.print(obj); + } + + public void println() { + trackContentLengthLn(); + delegate.println(); + } + + public void println(boolean x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(char x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(int x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(long x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(float x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(double x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(char[] x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(String x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public void println(Object x) { + trackContentLength(x); + trackContentLengthLn(); + delegate.println(x); + } + + public PrintWriter printf(String format, Object... args) { + return delegate.printf(format, args); + } + + public PrintWriter printf(Locale l, String format, Object... args) { + return delegate.printf(l, format, args); + } + + public PrintWriter format(String format, Object... args) { + return delegate.format(format, args); + } + + public PrintWriter format(Locale l, String format, Object... args) { + return delegate.format(l, format, args); + } + + public PrintWriter append(CharSequence csq) { + checkContentLength(csq.length()); + return delegate.append(csq); + } + + public PrintWriter append(CharSequence csq, int start, int end) { + checkContentLength(end - start); + return delegate.append(csq, start, end); + } + + public PrintWriter append(char c) { + trackContentLength(c); + return delegate.append(c); + } + } + + /** + * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling methods that commit the response. We delegate all methods + * to the original {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close to the original {@link javax.servlet.ServletOutputStream} + * as possible. See SEC-2039 + * + * @author Rob Winch + */ + private class SaveContextServletOutputStream extends ServletOutputStream { + private final ServletOutputStream delegate; + + public SaveContextServletOutputStream(ServletOutputStream delegate) { + this.delegate = delegate; + } + + public void write(int b) throws IOException { + trackContentLength(b); + this.delegate.write(b); + } + + public void flush() throws IOException { + doOnResponseCommitted(); + delegate.flush(); + } + + public void close() throws IOException { + doOnResponseCommitted(); + delegate.close(); + } + + public int hashCode() { + return delegate.hashCode(); + } + + public boolean equals(Object obj) { + return delegate.equals(obj); + } + + public void print(boolean b) throws IOException { + trackContentLength(b); + delegate.print(b); + } + + public void print(char c) throws IOException { + trackContentLength(c); + delegate.print(c); + } + + public void print(double d) throws IOException { + trackContentLength(d); + delegate.print(d); + } + + public void print(float f) throws IOException { + trackContentLength(f); + delegate.print(f); + } + + public void print(int i) throws IOException { + trackContentLength(i); + delegate.print(i); + } + + public void print(long l) throws IOException { + trackContentLength(l); + delegate.print(l); + } + + public void print(String s) throws IOException { + trackContentLength(s); + delegate.print(s); + } + + public void println() throws IOException { + trackContentLengthLn(); + delegate.println(); + } + + public void println(boolean b) throws IOException { + trackContentLength(b); + trackContentLengthLn(); + delegate.println(b); + } + + public void println(char c) throws IOException { + trackContentLength(c); + trackContentLengthLn(); + delegate.println(c); + } + + public void println(double d) throws IOException { + trackContentLength(d); + trackContentLengthLn(); + delegate.println(d); + } + + public void println(float f) throws IOException { + trackContentLength(f); + trackContentLengthLn(); + delegate.println(f); + } + + public void println(int i) throws IOException { + trackContentLength(i); + trackContentLengthLn(); + delegate.println(i); + } + + public void println(long l) throws IOException { + trackContentLength(l); + trackContentLengthLn(); + delegate.println(l); + } + + public void println(String s) throws IOException { + trackContentLength(s); + trackContentLengthLn(); + delegate.println(s); + } + + public void write(byte[] b) throws IOException { + trackContentLength(b); + delegate.write(b); + } + + public void write(byte[] b, int off, int len) throws IOException { + checkContentLength(len); + delegate.write(b, off, len); + } + + public String toString() { + return getClass().getName() + "[delegate=" + delegate.toString() + "]"; + } + } +} \ No newline at end of file diff --git a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java index 0c9983b4bf..701f50bfb5 100644 --- a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java +++ b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java @@ -12,13 +12,7 @@ */ package org.springframework.security.web.context; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Locale; - -import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpServletResponseWrapper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -26,11 +20,13 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; /** - * Base class for response wrappers which encapsulate the logic for storing a security context and which store the - * SecurityContext when a sendError(), sendRedirect, - * getOutputStream().close(), getOutputStream().flush(), getWriter().close(), or - * getWriter().flush() happens on the same thread that this - * {@link SaveContextOnUpdateOrErrorResponseWrapper} was created. See issue SEC-398 and SEC-2005. + * Base class for response wrappers which encapsulate the logic for storing a security + * context and which store the SecurityContext when a + * sendError(), sendRedirect, + * getOutputStream().close(), getOutputStream().flush(), + * getWriter().close(), or getWriter().flush() happens on the + * same thread that this {@link SaveContextOnUpdateOrErrorResponseWrapper} was created. + * See issue SEC-398 and SEC-2005. *

* Sub-classes should implement the {@link #saveContext(SecurityContext context)} method. *

@@ -41,33 +37,35 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Rob Winch * @since 3.0 */ -public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServletResponseWrapper { +public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends + OnCommittedResponseWrapper { private final Log logger = LogFactory.getLog(getClass()); - private boolean disableSaveOnResponseCommitted; private boolean contextSaved = false; /* See SEC-1052 */ private final boolean disableUrlRewriting; /** - * @param response the response to be wrapped - * @param disableUrlRewriting turns the URL encoding methods into null operations, preventing the use - * of URL rewriting to add the session identifier as a URL parameter. + * @param response the response to be wrapped + * @param disableUrlRewriting turns the URL encoding methods into null operations, + * preventing the use of URL rewriting to add the session identifier as a URL + * parameter. */ - public SaveContextOnUpdateOrErrorResponseWrapper(HttpServletResponse response, boolean disableUrlRewriting) { + public SaveContextOnUpdateOrErrorResponseWrapper(HttpServletResponse response, + boolean disableUrlRewriting) { super(response); this.disableUrlRewriting = disableUrlRewriting; } /** - * Invoke this method to disable automatic saving of the - * {@link SecurityContext} when the {@link HttpServletResponse} is - * committed. This can be useful in the event that Async Web Requests are - * made which may no longer contain the {@link SecurityContext} on it. + * Invoke this method to disable automatic saving of the {@link SecurityContext} when + * the {@link HttpServletResponse} is committed. This can be useful in the event that + * Async Web Requests are made which may no longer contain the {@link SecurityContext} + * on it. */ public void disableSaveOnResponseCommitted() { - this.disableSaveOnResponseCommitted = true; + disableOnResponseCommitted(); } /** @@ -77,76 +75,15 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServ */ protected abstract void saveContext(SecurityContext context); - /** - * Makes sure the session is updated before calling the - * superclass sendError() - */ - @Override - public final void sendError(int sc) throws IOException { - doSaveContext(); - super.sendError(sc); - } - - /** - * Makes sure the session is updated before calling the - * superclass sendError() - */ - @Override - public final void sendError(int sc, String msg) throws IOException { - doSaveContext(); - super.sendError(sc, msg); - } - - /** - * Makes sure the context is stored before calling the - * superclass sendRedirect() - */ - @Override - public final void sendRedirect(String location) throws IOException { - doSaveContext(); - super.sendRedirect(location); - } - - /** - * Makes sure the context is stored before calling getOutputStream().close() or - * getOutputStream().flush() - */ - @Override - public ServletOutputStream getOutputStream() throws IOException { - return new SaveContextServletOutputStream(super.getOutputStream()); - } - - /** - * Makes sure the context is stored before calling getWriter().close() or - * getWriter().flush() - */ - @Override - public PrintWriter getWriter() throws IOException { - return new SaveContextPrintWriter(super.getWriter()); - } - - /** - * Makes sure the context is stored before calling the - * superclass flushBuffer() - */ - @Override - public void flushBuffer() throws IOException { - doSaveContext(); - super.flushBuffer(); - } - /** * Calls saveContext() with the current contents of the - * SecurityContextHolder as long as - * {@link #disableSaveOnResponseCommitted()()} was not invoked. + * SecurityContextHolder as long as {@link #disableSaveOnResponseCommitted() + * ()} was not invoked. */ - private void doSaveContext() { - if(!disableSaveOnResponseCommitted) { - saveContext(SecurityContextHolder.getContext()); - contextSaved = true; - } else if(logger.isDebugEnabled()){ - logger.debug("Skip saving SecurityContext since saving on response commited is disabled"); - } + @Override + protected void onResponseCommitted() { + saveContext(SecurityContextHolder.getContext()); + contextSaved = true; } @Override @@ -182,283 +119,10 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServ } /** - * Tells if the response wrapper has called saveContext() because of this wrapper. + * Tells if the response wrapper has called saveContext() because of this + * wrapper. */ public final boolean isContextSaved() { return contextSaved; } - - /** - * Ensures the {@link SecurityContext} is updated prior to methods that commit the response. We delegate all methods - * to the original {@link PrintWriter} to ensure that the behavior is as close to the original {@link PrintWriter} - * as possible. See SEC-2039 - * @author Rob Winch - */ - private class SaveContextPrintWriter extends PrintWriter { - private final PrintWriter delegate; - - public SaveContextPrintWriter(PrintWriter delegate) { - super(delegate); - this.delegate = delegate; - } - - public void flush() { - doSaveContext(); - delegate.flush(); - } - - public void close() { - doSaveContext(); - delegate.close(); - } - - public int hashCode() { - return delegate.hashCode(); - } - - public boolean equals(Object obj) { - return delegate.equals(obj); - } - - public String toString() { - return getClass().getName() + "[delegate=" + delegate.toString() + "]"; - } - - public boolean checkError() { - return delegate.checkError(); - } - - public void write(int c) { - delegate.write(c); - } - - public void write(char[] buf, int off, int len) { - delegate.write(buf, off, len); - } - - public void write(char[] buf) { - delegate.write(buf); - } - - public void write(String s, int off, int len) { - delegate.write(s, off, len); - } - - public void write(String s) { - delegate.write(s); - } - - public void print(boolean b) { - delegate.print(b); - } - - public void print(char c) { - delegate.print(c); - } - - public void print(int i) { - delegate.print(i); - } - - public void print(long l) { - delegate.print(l); - } - - public void print(float f) { - delegate.print(f); - } - - public void print(double d) { - delegate.print(d); - } - - public void print(char[] s) { - delegate.print(s); - } - - public void print(String s) { - delegate.print(s); - } - - public void print(Object obj) { - delegate.print(obj); - } - - public void println() { - delegate.println(); - } - - public void println(boolean x) { - delegate.println(x); - } - - public void println(char x) { - delegate.println(x); - } - - public void println(int x) { - delegate.println(x); - } - - public void println(long x) { - delegate.println(x); - } - - public void println(float x) { - delegate.println(x); - } - - public void println(double x) { - delegate.println(x); - } - - public void println(char[] x) { - delegate.println(x); - } - - public void println(String x) { - delegate.println(x); - } - - public void println(Object x) { - delegate.println(x); - } - - public PrintWriter printf(String format, Object... args) { - return delegate.printf(format, args); - } - - public PrintWriter printf(Locale l, String format, Object... args) { - return delegate.printf(l, format, args); - } - - public PrintWriter format(String format, Object... args) { - return delegate.format(format, args); - } - - public PrintWriter format(Locale l, String format, Object... args) { - return delegate.format(l, format, args); - } - - public PrintWriter append(CharSequence csq) { - return delegate.append(csq); - } - - public PrintWriter append(CharSequence csq, int start, int end) { - return delegate.append(csq, start, end); - } - - public PrintWriter append(char c) { - return delegate.append(c); - } - } - - /** - * Ensures the {@link SecurityContext} is updated prior to methods that commit the response. We delegate all methods - * to the original {@link ServletOutputStream} to ensure that the behavior is as close to the original {@link ServletOutputStream} - * as possible. See SEC-2039 - * - * @author Rob Winch - */ - private class SaveContextServletOutputStream extends ServletOutputStream { - private final ServletOutputStream delegate; - - public SaveContextServletOutputStream(ServletOutputStream delegate) { - this.delegate = delegate; - } - - public void write(int b) throws IOException { - this.delegate.write(b); - } - - public void flush() throws IOException { - doSaveContext(); - delegate.flush(); - } - - public void close() throws IOException { - doSaveContext(); - delegate.close(); - } - - public int hashCode() { - return delegate.hashCode(); - } - - public boolean equals(Object obj) { - return delegate.equals(obj); - } - - public void print(boolean b) throws IOException { - delegate.print(b); - } - - public void print(char c) throws IOException { - delegate.print(c); - } - - public void print(double d) throws IOException { - delegate.print(d); - } - - public void print(float f) throws IOException { - delegate.print(f); - } - - public void print(int i) throws IOException { - delegate.print(i); - } - - public void print(long l) throws IOException { - delegate.print(l); - } - - public void print(String arg0) throws IOException { - delegate.print(arg0); - } - - public void println() throws IOException { - delegate.println(); - } - - public void println(boolean b) throws IOException { - delegate.println(b); - } - - public void println(char c) throws IOException { - delegate.println(c); - } - - public void println(double d) throws IOException { - delegate.println(d); - } - - public void println(float f) throws IOException { - delegate.println(f); - } - - public void println(int i) throws IOException { - delegate.println(i); - } - - public void println(long l) throws IOException { - delegate.println(l); - } - - public void println(String s) throws IOException { - delegate.println(s); - } - - public void write(byte[] b) throws IOException { - delegate.write(b); - } - - public void write(byte[] b, int off, int len) throws IOException { - delegate.write(b, off, len); - } - - public String toString() { - return getClass().getName() + "[delegate=" + delegate.toString() + "]"; - } - } } diff --git a/web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java b/web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java new file mode 100644 index 0000000000..93336ca867 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java @@ -0,0 +1,1122 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.context; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Locale; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServletResponse; + +import static org.fest.assertions.Assertions.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class OnCommittedResponseWrapperTests { + private static final String NL = "\r\n"; + + @Mock + HttpServletResponse delegate; + @Mock + PrintWriter writer; + @Mock + ServletOutputStream out; + + OnCommittedResponseWrapper response; + + boolean committed; + + @Before + public void setup() throws Exception { + response = new OnCommittedResponseWrapper(delegate) { + @Override + protected void onResponseCommitted() { + committed = true; + } + }; + when(delegate.getWriter()).thenReturn(writer); + when(delegate.getOutputStream()).thenReturn(out); + } + + + // --- printwriter + + @Test + public void printWriterHashCode() throws Exception { + int expected = writer.hashCode(); + + assertThat(response.getWriter().hashCode()).isEqualTo(expected); + } + + @Test + public void printWriterCheckError() throws Exception { + boolean expected = true; + when(writer.checkError()).thenReturn(expected); + + assertThat(response.getWriter().checkError()).isEqualTo(expected); + } + + @Test + public void printWriterWriteInt() throws Exception { + int expected = 1; + + response.getWriter().write(expected); + + verify(writer).write(expected); + } + + @Test + public void printWriterWriteCharIntInt() throws Exception { + char[] buff = new char[0]; + int off = 2; + int len = 3; + + response.getWriter().write(buff,off,len); + + verify(writer).write(buff,off,len); + } + + @Test + public void printWriterWriteChar() throws Exception { + char[] buff = new char[0]; + + response.getWriter().write(buff); + + verify(writer).write(buff); + } + + @Test + public void printWriterWriteStringIntInt() throws Exception { + String s = ""; + int off = 2; + int len = 3; + + response.getWriter().write(s,off,len); + + verify(writer).write(s,off,len); + } + + @Test + public void printWriterWriteString() throws Exception { + String s = ""; + + response.getWriter().write(s); + + verify(writer).write(s); + } + + @Test + public void printWriterPrintBoolean() throws Exception { + boolean b = true; + + response.getWriter().print(b); + + verify(writer).print(b); + } + + @Test + public void printWriterPrintChar() throws Exception { + char c = 1; + + response.getWriter().print(c); + + verify(writer).print(c); + } + + @Test + public void printWriterPrintInt() throws Exception { + int i = 1; + + response.getWriter().print(i); + + verify(writer).print(i); + } + + @Test + public void printWriterPrintLong() throws Exception { + long l = 1; + + response.getWriter().print(l); + + verify(writer).print(l); + } + + @Test + public void printWriterPrintFloat() throws Exception { + float f = 1; + + response.getWriter().print(f); + + verify(writer).print(f); + } + + @Test + public void printWriterPrintDouble() throws Exception { + double x = 1; + + response.getWriter().print(x); + + verify(writer).print(x); + } + + @Test + public void printWriterPrintCharArray() throws Exception { + char[] x = new char[0]; + + response.getWriter().print(x); + + verify(writer).print(x); + } + + @Test + public void printWriterPrintString() throws Exception { + String x = "1"; + + response.getWriter().print(x); + + verify(writer).print(x); + } + + @Test + public void printWriterPrintObject() throws Exception { + Object x = "1"; + + response.getWriter().print(x); + + verify(writer).print(x); + } + + @Test + public void printWriterPrintln() throws Exception { + response.getWriter().println(); + + verify(writer).println(); + } + + @Test + public void printWriterPrintlnBoolean() throws Exception { + boolean b = true; + + response.getWriter().println(b); + + verify(writer).println(b); + } + + @Test + public void printWriterPrintlnChar() throws Exception { + char c = 1; + + response.getWriter().println(c); + + verify(writer).println(c); + } + + @Test + public void printWriterPrintlnInt() throws Exception { + int i = 1; + + response.getWriter().println(i); + + verify(writer).println(i); + } + + @Test + public void printWriterPrintlnLong() throws Exception { + long l = 1; + + response.getWriter().println(l); + + verify(writer).println(l); + } + + @Test + public void printWriterPrintlnFloat() throws Exception { + float f = 1; + + response.getWriter().println(f); + + verify(writer).println(f); + } + + @Test + public void printWriterPrintlnDouble() throws Exception { + double x = 1; + + response.getWriter().println(x); + + verify(writer).println(x); + } + + @Test + public void printWriterPrintlnCharArray() throws Exception { + char[] x = new char[0]; + + response.getWriter().println(x); + + verify(writer).println(x); + } + + @Test + public void printWriterPrintlnString() throws Exception { + String x = "1"; + + response.getWriter().println(x); + + verify(writer).println(x); + } + + @Test + public void printWriterPrintlnObject() throws Exception { + Object x = "1"; + + response.getWriter().println(x); + + verify(writer).println(x); + } + + @Test + public void printWriterPrintfStringObjectVargs() throws Exception { + String format = "format"; + Object[] args = new Object[] { "1" }; + + response.getWriter().printf(format, args); + + verify(writer).printf(format, args); + } + + @Test + public void printWriterPrintfLocaleStringObjectVargs() throws Exception { + Locale l = Locale.US; + String format = "format"; + Object[] args = new Object[] { "1" }; + + response.getWriter().printf(l, format, args); + + verify(writer).printf(l, format, args); + } + + @Test + public void printWriterFormatStringObjectVargs() throws Exception { + String format = "format"; + Object[] args = new Object[] { "1" }; + + response.getWriter().format(format, args); + + verify(writer).format(format, args); + } + + @Test + public void printWriterFormatLocaleStringObjectVargs() throws Exception { + Locale l = Locale.US; + String format = "format"; + Object[] args = new Object[] { "1" }; + + response.getWriter().format(l, format, args); + + verify(writer).format(l, format, args); + } + + + @Test + public void printWriterAppendCharSequence() throws Exception { + String x = "a"; + + response.getWriter().append(x); + + verify(writer).append(x); + } + + @Test + public void printWriterAppendCharSequenceIntInt() throws Exception { + String x = "abcdef"; + int start = 1; + int end = 3; + + response.getWriter().append(x, start, end); + + verify(writer).append(x, start, end); + } + + + @Test + public void printWriterAppendChar() throws Exception { + char x = 1; + + response.getWriter().append(x); + + verify(writer).append(x); + } + + // servletoutputstream + + + @Test + public void outputStreamHashCode() throws Exception { + int expected = out.hashCode(); + + assertThat(response.getOutputStream().hashCode()).isEqualTo(expected); + } + + @Test + public void outputStreamWriteInt() throws Exception { + int expected = 1; + + response.getOutputStream().write(expected); + + verify(out).write(expected); + } + + @Test + public void outputStreamWriteByte() throws Exception { + byte[] expected = new byte[0]; + + response.getOutputStream().write(expected); + + verify(out).write(expected); + } + + @Test + public void outputStreamWriteByteIntInt() throws Exception { + int start = 1; + int end = 2; + byte[] expected = new byte[0]; + + response.getOutputStream().write(expected, start, end); + + verify(out).write(expected, start, end); + } + + @Test + public void outputStreamPrintBoolean() throws Exception { + boolean b = true; + + response.getOutputStream().print(b); + + verify(out).print(b); + } + + @Test + public void outputStreamPrintChar() throws Exception { + char c = 1; + + response.getOutputStream().print(c); + + verify(out).print(c); + } + + @Test + public void outputStreamPrintInt() throws Exception { + int i = 1; + + response.getOutputStream().print(i); + + verify(out).print(i); + } + + @Test + public void outputStreamPrintLong() throws Exception { + long l = 1; + + response.getOutputStream().print(l); + + verify(out).print(l); + } + + @Test + public void outputStreamPrintFloat() throws Exception { + float f = 1; + + response.getOutputStream().print(f); + + verify(out).print(f); + } + + @Test + public void outputStreamPrintDouble() throws Exception { + double x = 1; + + response.getOutputStream().print(x); + + verify(out).print(x); + } + + @Test + public void outputStreamPrintString() throws Exception { + String x = "1"; + + response.getOutputStream().print(x); + + verify(out).print(x); + } + + @Test + public void outputStreamPrintln() throws Exception { + response.getOutputStream().println(); + + verify(out).println(); + } + + @Test + public void outputStreamPrintlnBoolean() throws Exception { + boolean b = true; + + response.getOutputStream().println(b); + + verify(out).println(b); + } + + @Test + public void outputStreamPrintlnChar() throws Exception { + char c = 1; + + response.getOutputStream().println(c); + + verify(out).println(c); + } + + @Test + public void outputStreamPrintlnInt() throws Exception { + int i = 1; + + response.getOutputStream().println(i); + + verify(out).println(i); + } + + @Test + public void outputStreamPrintlnLong() throws Exception { + long l = 1; + + response.getOutputStream().println(l); + + verify(out).println(l); + } + + @Test + public void outputStreamPrintlnFloat() throws Exception { + float f = 1; + + response.getOutputStream().println(f); + + verify(out).println(f); + } + + @Test + public void outputStreamPrintlnDouble() throws Exception { + double x = 1; + + response.getOutputStream().println(x); + + verify(out).println(x); + } + + @Test + public void outputStreamPrintlnString() throws Exception { + String x = "1"; + + response.getOutputStream().println(x); + + verify(out).println(x); + } + + // The amount of content specified in the setContentLength method of the response + // has been greater than zero and has been written to the response. + + @Test + public void contentLengthPrintWriterWriteIntCommits() throws Exception { + int expected = 1; + response.setContentLength(String.valueOf(expected).length()); + + response.getWriter().write(expected); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterWriteIntMultiDigitCommits() throws Exception { + int expected = 10000; + response.setContentLength(String.valueOf(expected).length()); + + response.getWriter().write(expected); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPlus1PrintWriterWriteIntMultiDigitCommits() throws Exception { + int expected = 10000; + response.setContentLength(String.valueOf(expected).length() + 1); + + response.getWriter().write(expected); + + assertThat(committed).isFalse(); + + response.getWriter().write(1); + + assertThat(committed).isTrue(); + } + + + @Test + public void contentLengthPrintWriterWriteCharIntIntCommits() throws Exception { + char[] buff = new char[0]; + int off = 2; + int len = 3; + response.setContentLength(3); + + response.getWriter().write(buff,off,len); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterWriteCharCommits() throws Exception { + char[] buff = new char[4]; + response.setContentLength(buff.length); + + response.getWriter().write(buff); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterWriteStringIntIntCommits() throws Exception { + String s = ""; + int off = 2; + int len = 3; + response.setContentLength(3); + + response.getWriter().write(s,off,len); + + assertThat(committed).isTrue(); + } + + + @Test + public void contentLengthPrintWriterWriteStringCommits() throws IOException { + String body = "something"; + response.setContentLength(body.length()); + + response.getWriter().write(body); + + assertThat(committed).isTrue(); + } + + @Test + public void printWriterWriteStringContentLengthCommits() throws IOException { + String body = "something"; + response.getWriter().write(body); + + response.setContentLength(body.length()); + + assertThat(committed).isTrue(); + } + + @Test + public void printWriterWriteStringDoesNotCommit() throws IOException { + String body = "something"; + + response.getWriter().write(body); + + assertThat(committed).isFalse(); + } + + @Test + public void contentLengthPrintWriterPrintBooleanCommits() throws Exception { + boolean b = true; + response.setContentLength(1); + + response.getWriter().print(b); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintCharCommits() throws Exception { + char c = 1; + response.setContentLength(1); + + response.getWriter().print(c); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintIntCommits() throws Exception { + int i = 1234; + response.setContentLength(String.valueOf(i).length()); + + response.getWriter().print(i); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintLongCommits() throws Exception { + long l = 12345; + response.setContentLength(String.valueOf(l).length()); + + response.getWriter().print(l); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintFloatCommits() throws Exception { + float f = 12345; + response.setContentLength(String.valueOf(f).length()); + + response.getWriter().print(f); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintDoubleCommits() throws Exception { + double x = 1.2345; + response.setContentLength(String.valueOf(x).length()); + + response.getWriter().print(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintCharArrayCommits() throws Exception { + char[] x = new char[10]; + response.setContentLength(x.length); + + response.getWriter().print(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintStringCommits() throws Exception { + String x = "12345"; + response.setContentLength(x.length()); + + response.getWriter().print(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintObjectCommits() throws Exception { + Object x = "12345"; + response.setContentLength(String.valueOf(x).length()); + + response.getWriter().print(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnCommits() throws Exception { + response.setContentLength(NL.length()); + + response.getWriter().println(); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnBooleanCommits() throws Exception { + boolean b = true; + response.setContentLength(1); + + response.getWriter().println(b); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnCharCommits() throws Exception { + char c = 1; + response.setContentLength(1); + + response.getWriter().println(c); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnIntCommits() throws Exception { + int i = 12345; + response.setContentLength(String.valueOf(i).length()); + + response.getWriter().println(i); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnLongCommits() throws Exception { + long l = 12345678; + response.setContentLength(String.valueOf(l).length()); + + response.getWriter().println(l); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnFloatCommits() throws Exception { + float f = 1234; + response.setContentLength(String.valueOf(f).length()); + + response.getWriter().println(f); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnDoubleCommits() throws Exception { + double x = 1; + response.setContentLength(String.valueOf(x).length()); + + response.getWriter().println(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnCharArrayCommits() throws Exception { + char[] x = new char[20]; + response.setContentLength(x.length); + + response.getWriter().println(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnStringCommits() throws Exception { + String x = "1"; + response.setContentLength(String.valueOf(x).length()); + + response.getWriter().println(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterPrintlnObjectCommits() throws Exception { + Object x = "1"; + response.setContentLength(String.valueOf(x).length()); + + response.getWriter().println(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterAppendCharSequenceCommits() throws Exception { + String x = "a"; + response.setContentLength(String.valueOf(x).length()); + + response.getWriter().append(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterAppendCharSequenceIntIntCommits() throws Exception { + String x = "abcdef"; + int start = 1; + int end = 3; + response.setContentLength(end - start); + + response.getWriter().append(x, start, end); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPrintWriterAppendCharCommits() throws Exception { + char x = 1; + response.setContentLength(1); + + response.getWriter().append(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamWriteIntCommits() throws Exception { + int expected = 1; + response.setContentLength(String.valueOf(expected).length()); + + response.getOutputStream().write(expected); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamWriteIntMultiDigitCommits() throws Exception { + int expected = 10000; + response.setContentLength(String.valueOf(expected).length()); + + response.getOutputStream().write(expected); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthPlus1OutputStreamWriteIntMultiDigitCommits() throws Exception { + int expected = 10000; + response.setContentLength(String.valueOf(expected).length() + 1); + + response.getOutputStream().write(expected); + + assertThat(committed).isFalse(); + + response.getOutputStream().write(1); + + assertThat(committed).isTrue(); + } + + // gh-171 + @Test + public void contentLengthPlus1OutputStreamWriteByteArrayMultiDigitCommits() throws Exception { + String expected = "{\n" + + " \"parameterName\" : \"_csrf\",\n" + + " \"token\" : \"06300b65-c4aa-4c8f-8cda-39ee17f545a0\",\n" + + " \"headerName\" : \"X-CSRF-TOKEN\"\n" + + "}"; + response.setContentLength(expected.length() + 1); + + response.getOutputStream().write(expected.getBytes()); + + assertThat(committed).isFalse(); + + response.getOutputStream().write("1".getBytes("UTF-8")); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintBooleanCommits() throws Exception { + boolean b = true; + response.setContentLength(1); + + response.getOutputStream().print(b); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintCharCommits() throws Exception { + char c = 1; + response.setContentLength(1); + + response.getOutputStream().print(c); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintIntCommits() throws Exception { + int i = 1234; + response.setContentLength(String.valueOf(i).length()); + + response.getOutputStream().print(i); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintLongCommits() throws Exception { + long l = 12345; + response.setContentLength(String.valueOf(l).length()); + + response.getOutputStream().print(l); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintFloatCommits() throws Exception { + float f = 12345; + response.setContentLength(String.valueOf(f).length()); + + response.getOutputStream().print(f); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintDoubleCommits() throws Exception { + double x = 1.2345; + response.setContentLength(String.valueOf(x).length()); + + response.getOutputStream().print(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintStringCommits() throws Exception { + String x = "12345"; + response.setContentLength(x.length()); + + response.getOutputStream().print(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnCommits() throws Exception { + response.setContentLength(NL.length()); + + response.getOutputStream().println(); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnBooleanCommits() throws Exception { + boolean b = true; + response.setContentLength(1); + + response.getOutputStream().println(b); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnCharCommits() throws Exception { + char c = 1; + response.setContentLength(1); + + response.getOutputStream().println(c); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnIntCommits() throws Exception { + int i = 12345; + response.setContentLength(String.valueOf(i).length()); + + response.getOutputStream().println(i); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnLongCommits() throws Exception { + long l = 12345678; + response.setContentLength(String.valueOf(l).length()); + + response.getOutputStream().println(l); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnFloatCommits() throws Exception { + float f = 1234; + response.setContentLength(String.valueOf(f).length()); + + response.getOutputStream().println(f); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnDoubleCommits() throws Exception { + double x = 1; + response.setContentLength(String.valueOf(x).length()); + + response.getOutputStream().println(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthOutputStreamPrintlnStringCommits() throws Exception { + String x = "1"; + response.setContentLength(String.valueOf(x).length()); + + response.getOutputStream().println(x); + + assertThat(committed).isTrue(); + } + + @Test + public void contentLengthDoesNotCommit() throws IOException { + String body = "something"; + + response.setContentLength(body.length()); + + assertThat(committed).isFalse(); + } + + @Test + public void contentLengthOutputStreamWriteStringCommits() throws IOException { + String body = "something"; + response.setContentLength(body.length()); + + response.getOutputStream().print(body); + + assertThat(committed).isTrue(); + } + + @Test + public void addHeaderContentLengthPrintWriterWriteStringCommits() throws Exception { + int expected = 1234; + response.addHeader("Content-Length",String.valueOf(String.valueOf(expected).length())); + + response.getWriter().write(expected); + + assertThat(committed).isTrue(); + } + + @Test + public void bufferSizePrintWriterWriteCommits() throws Exception { + String expected = "1234567890"; + when(response.getBufferSize()).thenReturn(expected.length()); + + response.getWriter().write(expected); + + assertThat(committed).isTrue(); + } + + @Test + public void bufferSizeCommitsOnce() throws Exception { + String expected = "1234567890"; + when(response.getBufferSize()).thenReturn(expected.length()); + + response.getWriter().write(expected); + + assertThat(committed).isTrue(); + + committed = false; + + response.getWriter().write(expected); + + assertThat(committed).isFalse(); + } +} \ No newline at end of file