SEC-2232: HeaderFactory to HeaderWriter

This commit is contained in:
Rob Winch 2013-07-26 09:01:12 -05:00
parent fd754c5cab
commit 8acd205486
11 changed files with 126 additions and 142 deletions

View File

@ -22,7 +22,7 @@ import org.springframework.beans.factory.support.ManagedList;
import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.security.web.headers.HeadersFilter;
import org.springframework.security.web.headers.StaticHeaderFactory;
import org.springframework.security.web.headers.StaticHeadersWriter;
import org.springframework.security.web.headers.frameoptions.*;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
@ -85,7 +85,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
if (StringUtils.hasText(headerFactoryRef)) {
headerFactories.add(new RuntimeBeanReference(headerFactoryRef));
} else {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class);
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
builder.addConstructorArgValue(headerElt.getAttribute(ATT_NAME));
builder.addConstructorArgValue(headerElt.getAttribute(ATT_VALUE));
headerFactories.add(builder.getBeanDefinition());
@ -96,7 +96,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
private void parseContentTypeOptionsElement(Element element) {
Element contentTypeElt = DomUtils.getChildElementByTagName(element, CONTENT_TYPE_ELEMENT);
if (contentTypeElt != null) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class);
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
builder.addConstructorArgValue(CONTENT_TYPE_OPTIONS_HEADER);
builder.addConstructorArgValue("nosniff");
headerFactories.add(builder.getBeanDefinition());
@ -104,7 +104,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
}
private void parseFrameOptionsElement(Element element, ParserContext parserContext) {
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(FrameOptionsHeaderFactory.class);
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(FrameOptionsHeaderWriter.class);
Element frameElt = DomUtils.getChildElementByTagName(element, FRAME_OPTIONS_ELEMENT);
if (frameElt != null) {
@ -170,7 +170,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
} else if (!enabled && block) {
parserContext.getReaderContext().error("<xss-protection enabled=\"false\"/> does not allow block=\"true\".", xssElt);
}
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class);
BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
builder.addConstructorArgValue(XSS_PROTECTION_HEADER);
builder.addConstructorArgValue(value);
headerFactories.add(builder.getBeanDefinition());

View File

@ -418,7 +418,7 @@
</section>
<section xml:id="nsa-header-ref">
<title><literal>header-ref</literal></title>
<para>Reference to a custom implementation of the <classname>HeaderFactory</classname> interface.</para>
<para>Reference to a custom implementation of the <classname>HeaderWriter</classname> interface.</para>
</section>
</section>
<section xml:id="nsa-header-parents">

View File

@ -1,23 +0,0 @@
package org.springframework.security.web.headers;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Contract for a factory that creates {@code Header} instances.
*
* @author Marten Deinum
* @since 3.2
* @see HeadersFilter
*/
public interface HeaderFactory {
/**
* Create a {@code Header} instance.
*
* @param request the request
* @param response the response
* @return the created Header or <code>null</code>
*/
Header create(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.headers;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Contract for a factory that creates {@code Header} instances.
*
* @see HeadersFilter
*
* @author Marten Deinum
* @author Rob Winch
* @since 3.2
*/
public interface HeaderWriter {
/**
* Create a {@code Header} instance.
*
* @param request the request
* @param response the response
*/
void writeHeaders(HttpServletRequest request, HttpServletResponse response);
}

View File

@ -34,10 +34,10 @@ import java.util.*;
*/
public class HeadersFilter extends OncePerRequestFilter {
/** Collection of HeaderFactory instances to produce Headers. */
private final List<HeaderFactory> factories;
/** Collection of {@link HeaderWriter} instances to write out the headers to the response . */
private final List<HeaderWriter> factories;
public HeadersFilter(List<HeaderFactory> factories) {
public HeadersFilter(List<HeaderWriter> factories) {
this.factories = factories;
}
@ -45,28 +45,8 @@ public class HeadersFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
for (HeaderFactory factory : factories) {
Header header = factory.create(request, response);
if (header != null) {
String name = header.getName();
String[] values = header.getValues();
boolean first = true;
for (String value : values) {
if (logger.isDebugEnabled()) {
logger.debug("Adding header '" + name + "' with value '"+value +"'");
}
if (first) {
response.setHeader(name, value);
first = false;
} else {
response.addHeader(name, value);
}
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("Factory produced no header.");
}
}
for (HeaderWriter factory : factories) {
factory.writeHeaders(request, response);
}
filterChain.doFilter(request, response);
}

View File

@ -6,23 +6,25 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.util.Assert;
/**
* {@code HeaderFactory} implementation which returns the same {@code Header} instance.
* {@code HeaderWriter} implementation which writes the same {@code Header} instance.
*
* @author Marten Deinum
* @since 3.2
*/
public class StaticHeaderFactory implements HeaderFactory {
public class StaticHeadersWriter implements HeaderWriter {
private final Header header;
public StaticHeaderFactory(String name, String... values) {
public StaticHeadersWriter(String name, String... values) {
Assert.hasText(name, "Header name is required");
Assert.notEmpty(values, "Header values cannot be null or empty");
Assert.noNullElements(values, "Header values cannot contain null values");
header = new Header(name, values);
}
public Header create(HttpServletRequest request, HttpServletResponse response) {
return header;
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
for(String value : header.getValues()) {
response.addHeader(header.getName(), value);
}
}
}

View File

@ -3,7 +3,7 @@ package org.springframework.security.web.headers.frameoptions;
import javax.servlet.http.HttpServletRequest;
/**
* Strategy interfaces used by the {@code FrameOptionsHeaderFactory} to determine the actual value to use for the
* Strategy interfaces used by the {@code FrameOptionsHeaderWriter} to determine the actual value to use for the
* X-Frame-Options header when using the ALLOW-FROM directive.
*
* @author Marten Deinum

View File

@ -1,13 +1,12 @@
package org.springframework.security.web.headers.frameoptions;
import org.springframework.security.web.headers.Header;
import org.springframework.security.web.headers.HeaderFactory;
import org.springframework.security.web.headers.HeaderWriter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* {@code HeaderFactory} implementation for the X-Frame-Options headers. When using the ALLOW-FROM directive the actual
* {@code HeaderWriter} implementation for the X-Frame-Options headers. When using the ALLOW-FROM directive the actual
* value is determined by a {@code AllowFromStrategy}.
*
* @author Marten Deinum
@ -15,7 +14,7 @@ import javax.servlet.http.HttpServletResponse;
*
* @see AllowFromStrategy
*/
public class FrameOptionsHeaderFactory implements HeaderFactory {
public class FrameOptionsHeaderWriter implements HeaderWriter {
public static final String FRAME_OPTIONS_HEADER = "X-Frame-Options";
@ -24,21 +23,21 @@ public class FrameOptionsHeaderFactory implements HeaderFactory {
private final AllowFromStrategy allowFromStrategy;
private final String mode;
public FrameOptionsHeaderFactory(String mode) {
public FrameOptionsHeaderWriter(String mode) {
this(mode, new NullAllowFromStrategy());
}
public FrameOptionsHeaderFactory(String mode, AllowFromStrategy allowFromStrategy) {
public FrameOptionsHeaderWriter(String mode, AllowFromStrategy allowFromStrategy) {
this.mode=mode;
this.allowFromStrategy=allowFromStrategy;
}
public Header create(HttpServletRequest request, HttpServletResponse response) {
public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
if (ALLOW_FROM.equals(mode)) {
String value = allowFromStrategy.apply(request);
return new Header(FRAME_OPTIONS_HEADER, ALLOW_FROM + " " + value);
response.addHeader(FRAME_OPTIONS_HEADER, ALLOW_FROM + " " + value);
} else {
return new Header(FRAME_OPTIONS_HEADER, mode);
response.addHeader(FRAME_OPTIONS_HEADER, mode);
}
}

View File

@ -15,32 +15,38 @@
*/
package org.springframework.security.web.headers;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.verify;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.*;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.matchers.JUnitMatchers.hasItems;
/**
* Tests for the {@code HeadersFilter}
*
* @author Marten Deinum
* @since 3.2
*/
@RunWith(MockitoJUnitRunner.class)
public class HeadersFilterTest {
@Mock
private HeaderWriter writer1;
@Mock
private HeaderWriter writer2;
@Test
public void noHeadersConfigured() throws Exception {
List<HeaderFactory> factories = new ArrayList();
HeadersFilter filter = new HeadersFilter(factories);
List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
HeadersFilter filter = new HeadersFilter(headerWriters);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();
@ -52,18 +58,11 @@ public class HeadersFilterTest {
@Test
public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
List<HeaderFactory> factories = new ArrayList();
MockHeaderFactory factory1 = new MockHeaderFactory();
factory1.setName("X-Header1");
factory1.setValue("foo");
MockHeaderFactory factory2 = new MockHeaderFactory();
factory2.setName("X-Header2");
factory2.setValue("bar");
List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
headerWriters.add(writer1);
headerWriters.add(writer2);
factories.add(factory1);
factories.add(factory2);
HeadersFilter filter = new HeadersFilter(factories);
HeadersFilter filter = new HeadersFilter(headerWriters);
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
@ -71,30 +70,7 @@ public class HeadersFilterTest {
filter.doFilter(request, response, filterChain);
Collection<String> headerNames = response.getHeaderNames();
assertThat(headerNames.size(), is(2));
assertThat(headerNames, hasItems("X-Header1", "X-Header2"));
assertThat(response.getHeader("X-Header1"), is("foo"));
assertThat(response.getHeader("X-Header2"), is("bar"));
}
private static final class MockHeaderFactory implements HeaderFactory {
private String name;
private String value;
public Header create(HttpServletRequest request, HttpServletResponse response) {
return new Header(name, value);
}
public void setName(String name) {
this.name=name;
}
public void setValue(String value) {
this.value=value;
}
verify(writer1).writeHeaders(request, response);
verify(writer2).writeHeaders(request, response);
}
}

View File

@ -1,26 +0,0 @@
package org.springframework.security.web.headers;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertSame;
import static org.springframework.test.util.MatcherAssertionErrors.assertThat;
/**
* Test for the {@code StaticHeaderFactory}
*
* @author Marten Deinum
* @since 3.2
*/
public class StaticHeaderFactoryTest {
@Test
public void sameHeaderShouldBeReturned() {
StaticHeaderFactory factory = new StaticHeaderFactory("X-header", "foo");
Header header = factory.create(null, null);
assertThat(header.getName(), is("X-header"));
assertThat(header.getValues()[0], is("foo"));
assertSame(header, factory.create(null, null));
}
}

View File

@ -0,0 +1,37 @@
package org.springframework.security.web.headers;
import static org.fest.assertions.Assertions.assertThat;
import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
/**
* Test for the {@code StaticHeadersWriter}
*
* @author Marten Deinum
* @since 3.2
*/
public class StaticHeaderWriterTests {
private MockHttpServletRequest request;
private MockHttpServletResponse response;
@Before
public void setup() {
request = new MockHttpServletRequest();
response = new MockHttpServletResponse();
}
@Test
public void sameHeaderShouldBeReturned() {
String headerName = "X-header";
String headerValue = "foo";
StaticHeadersWriter factory = new StaticHeadersWriter(headerName, headerValue);
factory.writeHeaders(request, response);
assertThat(response.getHeaderValues(headerName)).isEqualTo(Arrays.asList(headerValue));
}
}