diff --git a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java index 9aa6868aca..282184b3b3 100644 --- a/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java +++ b/web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2020 the original author or authors. + * Copyright 2012-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -610,19 +610,25 @@ public class StrictHttpFirewall implements HttpFirewall { @Override public long getDateHeader(String name) { - validateAllowedHeaderName(name); + if (name != null) { + validateAllowedHeaderName(name); + } return super.getDateHeader(name); } @Override public int getIntHeader(String name) { - validateAllowedHeaderName(name); + if (name != null) { + validateAllowedHeaderName(name); + } return super.getIntHeader(name); } @Override public String getHeader(String name) { - validateAllowedHeaderName(name); + if (name != null) { + validateAllowedHeaderName(name); + } String value = super.getHeader(name); if (value != null) { validateAllowedHeaderValue(value); @@ -632,7 +638,9 @@ public class StrictHttpFirewall implements HttpFirewall { @Override public Enumeration getHeaders(String name) { - validateAllowedHeaderName(name); + if (name != null) { + validateAllowedHeaderName(name); + } Enumeration headers = super.getHeaders(name); return new Enumeration() { @@ -673,7 +681,9 @@ public class StrictHttpFirewall implements HttpFirewall { @Override public String getParameter(String name) { - validateAllowedParameterName(name); + if (name != null) { + validateAllowedParameterName(name); + } String value = super.getParameter(name); if (value != null) { validateAllowedParameterValue(value); @@ -717,7 +727,9 @@ public class StrictHttpFirewall implements HttpFirewall { @Override public String[] getParameterValues(String name) { - validateAllowedParameterName(name); + if (name != null) { + validateAllowedParameterName(name); + } String[] values = super.getParameterValues(name); if (values != null) { for (String value : values) { diff --git a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java index a553597e79..3ca722f9bd 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2020 the original author or authors. + * Copyright 2012-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.junit.Test; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** @@ -690,4 +691,45 @@ public class StrictHttpFirewallTests { .isThrownBy(() -> request.getParameterValues("bad name")); } + // gh-9598 + @Test + public void getFirewalledRequestGetParameterWhenNameIsNullThenIllegalArgumentException() { + HttpServletRequest request = this.firewall.getFirewalledRequest(this.request); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> request.getParameter(null)); + } + + // gh-9598 + @Test + public void getFirewalledRequestGetParameterValuesWhenNameIsNullThenIllegalArgumentException() { + HttpServletRequest request = this.firewall.getFirewalledRequest(this.request); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> request.getParameterValues(null)); + } + + // gh-9598 + @Test + public void getFirewalledRequestGetHeaderWhenNameIsNullThenNull() { + HttpServletRequest request = this.firewall.getFirewalledRequest(this.request); + assertThat(request.getHeader(null)).isNull(); + } + + // gh-9598 + @Test + public void getFirewalledRequestGetHeadersWhenNameIsNullThenEmptyEnumeration() { + HttpServletRequest request = this.firewall.getFirewalledRequest(this.request); + assertThat(request.getHeaders(null).hasMoreElements()).isFalse(); + } + + // gh-9598 + @Test + public void getFirewalledRequestGetIntHeaderWhenNameIsNullThenNegativeOne() { + HttpServletRequest request = this.firewall.getFirewalledRequest(this.request); + assertThat(request.getIntHeader(null)).isEqualTo(-1); + } + + @Test + public void getFirewalledRequestGetDateHeaderWhenNameIsNullThenNegativeOne() { + HttpServletRequest request = this.firewall.getFirewalledRequest(this.request); + assertThat(request.getDateHeader(null)).isEqualTo(-1); + } + }