diff --git a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java index d963c79748..a06d1b93d2 100644 --- a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java +++ b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,12 @@ import java.io.IOException; import java.util.List; import javax.servlet.FilterChain; +import javax.servlet.RequestDispatcher; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import org.springframework.security.web.util.OnCommittedResponseWrapper; @@ -33,6 +37,7 @@ import org.springframework.web.filter.OncePerRequestFilter; * and X-Content-Type-Options. * * @author Marten Deinum + * @author Josh Cummings * @since 3.2 * */ @@ -62,8 +67,11 @@ public class HeaderWriterFilter extends OncePerRequestFilter { HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request, response, this.headerWriters); + HeaderWriterRequest headerWriterRequest = new HeaderWriterRequest(request, + headerWriterResponse); + try { - filterChain.doFilter(request, headerWriterResponse); + filterChain.doFilter(headerWriterRequest, headerWriterResponse); } finally { headerWriterResponse.writeHeaders(); @@ -106,4 +114,39 @@ public class HeaderWriterFilter extends OncePerRequestFilter { return (HttpServletResponse) getResponse(); } } + + static class HeaderWriterRequest extends HttpServletRequestWrapper { + private final HeaderWriterResponse response; + + HeaderWriterRequest(HttpServletRequest request, HeaderWriterResponse response) { + super(request); + this.response = response; + } + + @Override + public RequestDispatcher getRequestDispatcher(String path) { + return new HeaderWriterRequestDispatcher(super.getRequestDispatcher(path), this.response); + } + } + + static class HeaderWriterRequestDispatcher implements RequestDispatcher { + private final RequestDispatcher delegate; + private final HeaderWriterResponse response; + + HeaderWriterRequestDispatcher(RequestDispatcher delegate, HeaderWriterResponse response) { + this.delegate = delegate; + this.response = response; + } + + @Override + public void forward(ServletRequest request, ServletResponse response) throws ServletException, IOException { + this.delegate.forward(request, response); + } + + @Override + public void include(ServletRequest request, ServletResponse response) throws ServletException, IOException { + this.response.onResponseCommitted(); + this.delegate.include(request, response); + } + } } diff --git a/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java b/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java index 1c5a1a5169..07a84b6352 100644 --- a/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.security.web.header; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import javax.servlet.FilterChain; @@ -84,7 +85,9 @@ public class HeaderWriterFilterTests { verify(this.writer1).writeHeaders(request, response); verify(this.writer2).writeHeaders(request, response); - assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain + HeaderWriterFilter.HeaderWriterRequest wrappedRequest = (HeaderWriterFilter.HeaderWriterRequest) + filterChain.getRequest(); + assertThat(wrappedRequest.getRequest()).isEqualTo(request); // verify the filterChain // continued } @@ -112,4 +115,25 @@ public class HeaderWriterFilterTests { verifyNoMoreInteractions(this.writer1); } + + // gh-5499 + @Test + public void doFilterWhenRequestContainsIncludeThenHeadersStillWritten() throws Exception { + HeaderWriterFilter filter = new HeaderWriterFilter( + Collections.singletonList(this.writer1)); + + MockHttpServletRequest mockRequest = new MockHttpServletRequest(); + MockHttpServletResponse mockResponse = new MockHttpServletResponse(); + + filter.doFilter(mockRequest, mockResponse, (request, response) -> { + verifyZeroInteractions(HeaderWriterFilterTests.this.writer1); + + request.getRequestDispatcher("/").include(request, response); + + verify(HeaderWriterFilterTests.this.writer1).writeHeaders( + any(HttpServletRequest.class), any(HttpServletResponse.class)); + }); + + verifyNoMoreInteractions(this.writer1); + } }