SEC-2232: HeaderFactory to HeaderWriter

This commit is contained in:
Rob Winch 2013-07-26 09:01:12 -05:00
parent fd754c5cab
commit 8acd205486
11 changed files with 126 additions and 142 deletions

View File

@ -22,7 +22,7 @@ import org.springframework.beans.factory.support.ManagedList;
import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext; import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.security.web.headers.HeadersFilter; import org.springframework.security.web.headers.HeadersFilter;
import org.springframework.security.web.headers.StaticHeaderFactory; import org.springframework.security.web.headers.StaticHeadersWriter;
import org.springframework.security.web.headers.frameoptions.*; import org.springframework.security.web.headers.frameoptions.*;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils; import org.springframework.util.xml.DomUtils;
@ -85,7 +85,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
if (StringUtils.hasText(headerFactoryRef)) { if (StringUtils.hasText(headerFactoryRef)) {
headerFactories.add(new RuntimeBeanReference(headerFactoryRef)); headerFactories.add(new RuntimeBeanReference(headerFactoryRef));
} else { } else {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class); BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
builder.addConstructorArgValue(headerElt.getAttribute(ATT_NAME)); builder.addConstructorArgValue(headerElt.getAttribute(ATT_NAME));
builder.addConstructorArgValue(headerElt.getAttribute(ATT_VALUE)); builder.addConstructorArgValue(headerElt.getAttribute(ATT_VALUE));
headerFactories.add(builder.getBeanDefinition()); headerFactories.add(builder.getBeanDefinition());
@ -96,7 +96,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
private void parseContentTypeOptionsElement(Element element) { private void parseContentTypeOptionsElement(Element element) {
Element contentTypeElt = DomUtils.getChildElementByTagName(element, CONTENT_TYPE_ELEMENT); Element contentTypeElt = DomUtils.getChildElementByTagName(element, CONTENT_TYPE_ELEMENT);
if (contentTypeElt != null) { if (contentTypeElt != null) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class); BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
builder.addConstructorArgValue(CONTENT_TYPE_OPTIONS_HEADER); builder.addConstructorArgValue(CONTENT_TYPE_OPTIONS_HEADER);
builder.addConstructorArgValue("nosniff"); builder.addConstructorArgValue("nosniff");
headerFactories.add(builder.getBeanDefinition()); headerFactories.add(builder.getBeanDefinition());
@ -104,7 +104,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
} }
private void parseFrameOptionsElement(Element element, ParserContext parserContext) { private void parseFrameOptionsElement(Element element, ParserContext parserContext) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(FrameOptionsHeaderFactory.class); BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(FrameOptionsHeaderWriter.class);
Element frameElt = DomUtils.getChildElementByTagName(element, FRAME_OPTIONS_ELEMENT); Element frameElt = DomUtils.getChildElementByTagName(element, FRAME_OPTIONS_ELEMENT);
if (frameElt != null) { if (frameElt != null) {
@ -170,7 +170,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
} else if (!enabled && block) { } else if (!enabled && block) {
parserContext.getReaderContext().error("<xss-protection enabled=\"false\"/> does not allow block=\"true\".", xssElt); parserContext.getReaderContext().error("<xss-protection enabled=\"false\"/> does not allow block=\"true\".", xssElt);
} }
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class); BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
builder.addConstructorArgValue(XSS_PROTECTION_HEADER); builder.addConstructorArgValue(XSS_PROTECTION_HEADER);
builder.addConstructorArgValue(value); builder.addConstructorArgValue(value);
headerFactories.add(builder.getBeanDefinition()); headerFactories.add(builder.getBeanDefinition());

View File

@ -418,7 +418,7 @@
</section> </section>
<section xml:id="nsa-header-ref"> <section xml:id="nsa-header-ref">
<title><literal>header-ref</literal></title> <title><literal>header-ref</literal></title>
<para>Reference to a custom implementation of the <classname>HeaderFactory</classname> interface.</para> <para>Reference to a custom implementation of the <classname>HeaderWriter</classname> interface.</para>
</section> </section>
</section> </section>
<section xml:id="nsa-header-parents"> <section xml:id="nsa-header-parents">

View File

@ -1,23 +0,0 @@
package org.springframework.security.web.headers;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Contract for a factory that creates {@code Header} instances.
*
* @author Marten Deinum
* @since 3.2
* @see HeadersFilter
*/
public interface HeaderFactory {
/**
* Create a {@code Header} instance.
*
* @param request the request
* @param response the response
* @return the created Header or <code>null</code>
*/
Header create(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.headers;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Contract for a factory that creates {@code Header} instances.
*
* @see HeadersFilter
*
* @author Marten Deinum
* @author Rob Winch
* @since 3.2
*/
public interface HeaderWriter {
/**
* Create a {@code Header} instance.
*
* @param request the request
* @param response the response
*/
void writeHeaders(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -34,10 +34,10 @@ import java.util.*;
*/ */
public class HeadersFilter extends OncePerRequestFilter { public class HeadersFilter extends OncePerRequestFilter {
/** Collection of HeaderFactory instances to produce Headers. */ /** Collection of {@link HeaderWriter} instances to write out the headers to the response . */
private final List<HeaderFactory> factories; private final List<HeaderWriter> factories;
public HeadersFilter(List<HeaderFactory> factories) { public HeadersFilter(List<HeaderWriter> factories) {
this.factories = factories; this.factories = factories;
} }
@ -45,28 +45,8 @@ public class HeadersFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
for (HeaderFactory factory : factories) { for (HeaderWriter factory : factories) {
Header header = factory.create(request, response); factory.writeHeaders(request, response);
if (header != null) {
String name = header.getName();
String[] values = header.getValues();
boolean first = true;
for (String value : values) {
if (logger.isDebugEnabled()) {
logger.debug("Adding header '" + name + "' with value '"+value +"'");
}
if (first) {
response.setHeader(name, value);
first = false;
} else {
response.addHeader(name, value);
}
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("Factory produced no header.");
}
}
} }
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }

View File

@ -6,23 +6,25 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
* {@code HeaderFactory} implementation which returns the same {@code Header} instance. * {@code HeaderWriter} implementation which writes the same {@code Header} instance.
* *
* @author Marten Deinum * @author Marten Deinum
* @since 3.2 * @since 3.2
*/ */
public class StaticHeaderFactory implements HeaderFactory { public class StaticHeadersWriter implements HeaderWriter {
private final Header header; private final Header header;
public StaticHeaderFactory(String name, String... values) { public StaticHeadersWriter(String name, String... values) {
Assert.hasText(name, "Header name is required"); Assert.hasText(name, "Header name is required");
Assert.notEmpty(values, "Header values cannot be null or empty"); Assert.notEmpty(values, "Header values cannot be null or empty");
Assert.noNullElements(values, "Header values cannot contain null values"); Assert.noNullElements(values, "Header values cannot contain null values");
header = new Header(name, values); header = new Header(name, values);
} }
public Header create(HttpServletRequest request, HttpServletResponse response) { public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
return header; for(String value : header.getValues()) {
response.addHeader(header.getName(), value);
}
} }
} }

View File

@ -3,7 +3,7 @@ package org.springframework.security.web.headers.frameoptions;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
/** /**
* Strategy interfaces used by the {@code FrameOptionsHeaderFactory} to determine the actual value to use for the * Strategy interfaces used by the {@code FrameOptionsHeaderWriter} to determine the actual value to use for the
* X-Frame-Options header when using the ALLOW-FROM directive. * X-Frame-Options header when using the ALLOW-FROM directive.
* *
* @author Marten Deinum * @author Marten Deinum

View File

@ -1,13 +1,12 @@
package org.springframework.security.web.headers.frameoptions; package org.springframework.security.web.headers.frameoptions;
import org.springframework.security.web.headers.Header; import org.springframework.security.web.headers.HeaderWriter;
import org.springframework.security.web.headers.HeaderFactory;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
/** /**
* {@code HeaderFactory} implementation for the X-Frame-Options headers. When using the ALLOW-FROM directive the actual * {@code HeaderWriter} implementation for the X-Frame-Options headers. When using the ALLOW-FROM directive the actual
* value is determined by a {@code AllowFromStrategy}. * value is determined by a {@code AllowFromStrategy}.
* *
* @author Marten Deinum * @author Marten Deinum
@ -15,7 +14,7 @@ import javax.servlet.http.HttpServletResponse;
* *
* @see AllowFromStrategy * @see AllowFromStrategy
*/ */
public class FrameOptionsHeaderFactory implements HeaderFactory { public class FrameOptionsHeaderWriter implements HeaderWriter {
public static final String FRAME_OPTIONS_HEADER = "X-Frame-Options"; public static final String FRAME_OPTIONS_HEADER = "X-Frame-Options";
@ -24,21 +23,21 @@ public class FrameOptionsHeaderFactory implements HeaderFactory {
private final AllowFromStrategy allowFromStrategy; private final AllowFromStrategy allowFromStrategy;
private final String mode; private final String mode;
public FrameOptionsHeaderFactory(String mode) { public FrameOptionsHeaderWriter(String mode) {
this(mode, new NullAllowFromStrategy()); this(mode, new NullAllowFromStrategy());
} }
public FrameOptionsHeaderFactory(String mode, AllowFromStrategy allowFromStrategy) { public FrameOptionsHeaderWriter(String mode, AllowFromStrategy allowFromStrategy) {
this.mode=mode; this.mode=mode;
this.allowFromStrategy=allowFromStrategy; this.allowFromStrategy=allowFromStrategy;
} }
public Header create(HttpServletRequest request, HttpServletResponse response) { public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (ALLOW_FROM.equals(mode)) { if (ALLOW_FROM.equals(mode)) {
String value = allowFromStrategy.apply(request); String value = allowFromStrategy.apply(request);
return new Header(FRAME_OPTIONS_HEADER, ALLOW_FROM + " " + value); response.addHeader(FRAME_OPTIONS_HEADER, ALLOW_FROM + " " + value);
} else { } else {
return new Header(FRAME_OPTIONS_HEADER, mode); response.addHeader(FRAME_OPTIONS_HEADER, mode);
} }
} }

View File

@ -15,32 +15,38 @@
*/ */
package org.springframework.security.web.headers; package org.springframework.security.web.headers;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.verify;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.*;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.matchers.JUnitMatchers.hasItems;
/** /**
* Tests for the {@code HeadersFilter} * Tests for the {@code HeadersFilter}
* *
* @author Marten Deinum * @author Marten Deinum
* @since 3.2 * @since 3.2
*/ */
@RunWith(MockitoJUnitRunner.class)
public class HeadersFilterTest { public class HeadersFilterTest {
@Mock
private HeaderWriter writer1;
@Mock
private HeaderWriter writer2;
@Test @Test
public void noHeadersConfigured() throws Exception { public void noHeadersConfigured() throws Exception {
List<HeaderFactory> factories = new ArrayList(); List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
HeadersFilter filter = new HeadersFilter(factories); HeadersFilter filter = new HeadersFilter(headerWriters);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain(); MockFilterChain filterChain = new MockFilterChain();
@ -52,18 +58,11 @@ public class HeadersFilterTest {
@Test @Test
public void additionalHeadersShouldBeAddedToTheResponse() throws Exception { public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
List<HeaderFactory> factories = new ArrayList(); List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
MockHeaderFactory factory1 = new MockHeaderFactory(); headerWriters.add(writer1);
factory1.setName("X-Header1"); headerWriters.add(writer2);
factory1.setValue("foo");
MockHeaderFactory factory2 = new MockHeaderFactory();
factory2.setName("X-Header2");
factory2.setValue("bar");
factories.add(factory1); HeadersFilter filter = new HeadersFilter(headerWriters);
factories.add(factory2);
HeadersFilter filter = new HeadersFilter(factories);
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
@ -71,30 +70,7 @@ public class HeadersFilterTest {
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
Collection<String> headerNames = response.getHeaderNames(); verify(writer1).writeHeaders(request, response);
assertThat(headerNames.size(), is(2)); verify(writer2).writeHeaders(request, response);
assertThat(headerNames, hasItems("X-Header1", "X-Header2"));
assertThat(response.getHeader("X-Header1"), is("foo"));
assertThat(response.getHeader("X-Header2"), is("bar"));
}
private static final class MockHeaderFactory implements HeaderFactory {
private String name;
private String value;
public Header create(HttpServletRequest request, HttpServletResponse response) {
return new Header(name, value);
}
public void setName(String name) {
this.name=name;
}
public void setValue(String value) {
this.value=value;
}
} }
} }

View File

@ -1,26 +0,0 @@
package org.springframework.security.web.headers;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertSame;
import static org.springframework.test.util.MatcherAssertionErrors.assertThat;
/**
* Test for the {@code StaticHeaderFactory}
*
* @author Marten Deinum
* @since 3.2
*/
public class StaticHeaderFactoryTest {
@Test
public void sameHeaderShouldBeReturned() {
StaticHeaderFactory factory = new StaticHeaderFactory("X-header", "foo");
Header header = factory.create(null, null);
assertThat(header.getName(), is("X-header"));
assertThat(header.getValues()[0], is("foo"));
assertSame(header, factory.create(null, null));
}
}

View File

@ -0,0 +1,37 @@
package org.springframework.security.web.headers;
import static org.fest.assertions.Assertions.assertThat;
import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
/**
* Test for the {@code StaticHeadersWriter}
*
* @author Marten Deinum
* @since 3.2
*/
public class StaticHeaderWriterTests {
private MockHttpServletRequest request;
private MockHttpServletResponse response;
@Before
public void setup() {
request = new MockHttpServletRequest();
response = new MockHttpServletResponse();
}
@Test
public void sameHeaderShouldBeReturned() {
String headerName = "X-header";
String headerValue = "foo";
StaticHeadersWriter factory = new StaticHeadersWriter(headerName, headerValue);
factory.writeHeaders(request, response);
assertThat(response.getHeaderValues(headerName)).isEqualTo(Arrays.asList(headerValue));
}
}