diff --git a/web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java b/web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java index ecc22724d6..abc57bc175 100644 --- a/web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java +++ b/web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java @@ -58,37 +58,37 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public void addHeader(String name, String value) { + public void addHeader(@Nullable String name, @Nullable String value) { checkContentLengthHeader(name, value); super.addHeader(name, value); } @Override - public void addIntHeader(String name, int value) { + public void addIntHeader(@Nullable String name, int value) { checkContentLengthHeader(name, value); super.addIntHeader(name, value); } @Override - public void setHeader(String name, String value) { + public void setHeader(@Nullable String name, @Nullable String value) { checkContentLengthHeader(name, value); super.setHeader(name, value); } @Override - public void setIntHeader(String name, int value) { + public void setIntHeader(@Nullable String name, int value) { checkContentLengthHeader(name, value); super.setIntHeader(name, value); } - private void checkContentLengthHeader(String name, int value) { + private void checkContentLengthHeader(@Nullable String name, int value) { if ("Content-Length".equalsIgnoreCase(name)) { setContentLength(value); } } - private void checkContentLengthHeader(String name, String value) { - if ("Content-Length".equalsIgnoreCase(name)) { + private void checkContentLengthHeader(@Nullable String name, @Nullable String value) { + if (value != null && "Content-Length".equalsIgnoreCase(name)) { setContentLength(Long.parseLong(value)); } } @@ -150,7 +150,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap * before calling the superclass sendError() */ @Override - public final void sendError(int sc, String msg) throws IOException { + public final void sendError(int sc, @Nullable String msg) throws IOException { doOnResponseCommitted(); super.sendError(sc, msg); } @@ -160,7 +160,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap * before calling the superclass sendRedirect() */ @Override - public final void sendRedirect(String location) throws IOException { + public final void sendRedirect(@Nullable String location) throws IOException { doOnResponseCommitted(); super.sendRedirect(location); } @@ -207,7 +207,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } } - private void trackContentLength(Object content) { + private void trackContentLength(@Nullable Object content) { if (!this.disableOnCommitted) { trackContentLength(String.valueOf(content)); } @@ -249,7 +249,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } } - private void trackContentLength(String content) { + private void trackContentLength(@Nullable String content) { if (!this.disableOnCommitted) { int contentLength = (content != null) ? content.length() : 4; checkContentLength(contentLength); @@ -356,7 +356,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public void write(String s) { + public void write(@Nullable String s) { trackContentLength(s); this.delegate.write(s); } @@ -404,13 +404,13 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public void print(String s) { + public void print(@Nullable String s) { trackContentLength(s); this.delegate.print(s); } @Override - public void print(Object obj) { + public void print(@Nullable Object obj) { trackContentLength(obj); this.delegate.print(obj); } @@ -471,14 +471,14 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public void println(String x) { + public void println(@Nullable String x) { trackContentLength(x); trackContentLengthLn(); this.delegate.println(x); } @Override - public void println(Object x) { + public void println(@Nullable Object x) { trackContentLength(x); trackContentLengthLn(); this.delegate.println(x); @@ -505,13 +505,13 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public PrintWriter append(CharSequence csq) { - checkContentLength(csq.length()); + public PrintWriter append(@Nullable CharSequence csq) { + checkContentLength((csq != null) ? csq.length() : 4); return this.delegate.append(csq); } @Override - public PrintWriter append(CharSequence csq, int start, int end) { + public PrintWriter append(@Nullable CharSequence csq, int start, int end) { checkContentLength(end - start); return this.delegate.append(csq, start, end); } @@ -596,7 +596,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public void print(String s) throws IOException { + public void print(@Nullable String s) throws IOException { trackContentLength(s); this.delegate.print(s); } @@ -650,7 +650,7 @@ public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrap } @Override - public void println(String s) throws IOException { + public void println(@Nullable String s) throws IOException { trackContentLength(s); trackContentLengthLn(); this.delegate.println(s); diff --git a/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java b/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java index bbdb50b20c..94002ee650 100644 --- a/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java +++ b/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java @@ -28,7 +28,10 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.HttpHeaders; + import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; @@ -1006,6 +1009,16 @@ public class OnCommittedResponseWrapperTests { assertThat(this.committed).isTrue(); } + @Test + public void addHeaderNullNameDoesNotThrow() { + assertThatNoException().isThrownBy(() -> this.response.addHeader(null, "value")); + } + + @Test + public void addHeaderNullValueDoesNotThrow() { + assertThatNoException().isThrownBy(() -> this.response.addHeader(HttpHeaders.CONTENT_LENGTH, null)); + } + @Test public void addIntHeaderContentLengthPrintWriterWriteStringCommits() throws Exception { givenGetWriterThenReturn(); @@ -1015,6 +1028,11 @@ public class OnCommittedResponseWrapperTests { assertThat(this.committed).isTrue(); } + @Test + public void addIntHeaderNullNameDoesNotThrow() { + assertThatNoException().isThrownBy(() -> this.response.addIntHeader(null, 1)); + } + @Test public void setHeaderContentLengthPrintWriterWriteStringCommits() throws Exception { givenGetWriterThenReturn(); @@ -1024,6 +1042,16 @@ public class OnCommittedResponseWrapperTests { assertThat(this.committed).isTrue(); } + @Test + public void setHeaderNullNameDoesNotThrow() { + assertThatNoException().isThrownBy(() -> this.response.setHeader(null, "value")); + } + + @Test + public void setHeaderNullValueDoesNotThrow() { + assertThatNoException().isThrownBy(() -> this.response.setHeader(HttpHeaders.CONTENT_LENGTH, null)); + } + @Test public void setIntHeaderContentLengthPrintWriterWriteStringCommits() throws Exception { givenGetWriterThenReturn(); @@ -1033,6 +1061,11 @@ public class OnCommittedResponseWrapperTests { assertThat(this.committed).isTrue(); } + @Test + public void setIntHeaderNullNameDoesNotThrow() { + assertThatNoException().isThrownBy(() -> this.response.setIntHeader(null, 1)); + } + @Test public void bufferSizePrintWriterWriteCommits() throws Exception { givenGetWriterThenReturn(); @@ -1054,4 +1087,86 @@ public class OnCommittedResponseWrapperTests { assertThat(this.committed).isFalse(); } + @Test + public void printWriterPrintNullStringDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + String s = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().print(s)); + verify(this.writer).print(s); + } + + @Test + public void printWriterPrintlnNullStringDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + String s = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().println(s)); + verify(this.writer).println(s); + } + + @Test + public void printWriterPrintNullObjectDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + Object obj = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().print(obj)); + verify(this.writer).print(obj); + } + + @Test + public void printWriterPrintlnNullObjectDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + Object obj = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().println(obj)); + verify(this.writer).println(obj); + } + + @Test + public void printWriterWriteNullStringDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + String s = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().write(s)); + verify(this.writer).write(s); + } + + @Test + public void printWriterAppendNullCharSequenceDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + CharSequence csq = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().append(csq)); + verify(this.writer).append(csq); + } + + @Test + public void printWriterAppendNullCharSequenceIntIntDoesNotThrow() throws Exception { + givenGetWriterThenReturn(); + CharSequence csq = null; + assertThatNoException().isThrownBy(() -> this.response.getWriter().append(csq, 0, 3)); + verify(this.writer).append(csq, 0, 3); + } + + @Test + public void outputStreamPrintNullStringDoesNotThrow() throws Exception { + givenGetOutputStreamThenReturn(); + String s = null; + assertThatNoException().isThrownBy(() -> this.response.getOutputStream().print(s)); + verify(this.out).print(s); + } + + @Test + public void outputStreamPrintlnNullStringDoesNotThrow() throws Exception { + givenGetOutputStreamThenReturn(); + String s = null; + assertThatNoException().isThrownBy(() -> this.response.getOutputStream().println(s)); + verify(this.out).println(s); + } + + @Test + public void sendErrorWithNullMsgDoesNotThrow() throws Exception { + assertThatNoException().isThrownBy(() -> this.response.sendError(400, null)); + } + + @Test + public void sendRedirectWithNullLocationDoesNotThrow() throws Exception { + assertThatNoException().isThrownBy(() -> this.response.sendRedirect(null)); + } + }