Add HttpsRedirectFilter

Closes gh-16678
This commit is contained in:
Josh Cummings 2025-02-28 09:07:15 -07:00
parent ec19efbf2a
commit 2d96fba5cf
No known key found for this signature in database
GPG Key ID: 869B37A20E876129
2 changed files with 281 additions and 0 deletions

View File

@ -0,0 +1,112 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.transport;
import java.io.IOException;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.PortMapper;
import org.springframework.security.web.PortMapperImpl;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
/**
* Redirects any non-HTTPS request to its HTTPS equivalent.
*
* <p>
* Can be configured to use a {@link RequestMatcher} to narrow which requests get
* redirected.
*
* <p>
* Can also be configured for custom ports using {@link PortMapper}.
*
* @author Josh Cummings
* @since 6.5
*/
public final class HttpsRedirectFilter extends OncePerRequestFilter {
private PortMapper portMapper = new PortMapperImpl();
private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
if (!isInsecure(request)) {
chain.doFilter(request, response);
return;
}
if (!this.requestMatcher.matches(request)) {
chain.doFilter(request, response);
return;
}
String redirectUri = createRedirectUri(request);
this.redirectStrategy.sendRedirect(request, response, redirectUri);
}
/**
* Use this {@link PortMapper} for mapping custom ports
* @param portMapper the {@link PortMapper} to use
*/
public void setPortMapper(PortMapper portMapper) {
Assert.notNull(portMapper, "portMapper cannot be null");
this.portMapper = portMapper;
}
/**
* Use this {@link RequestMatcher} to narrow which requests are redirected to HTTPS.
*
* The filter already first checks for HTTPS in the uri scheme, so it is not necessary
* to include that check in this matcher.
* @param requestMatcher the {@link RequestMatcher} to use
*/
public void setRequestMatcher(RequestMatcher requestMatcher) {
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
this.requestMatcher = requestMatcher;
}
private boolean isInsecure(HttpServletRequest request) {
return !"https".equals(request.getScheme());
}
private String createRedirectUri(HttpServletRequest request) {
String url = UrlUtils.buildFullRequestUrl(request);
UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(url);
UriComponents components = builder.build();
int port = components.getPort();
if (port > 0) {
Integer httpsPort = this.portMapper.lookupHttpsPort(port);
Assert.state(httpsPort != null, () -> "HTTP Port '" + port + "' does not have a corresponding HTTPS Port");
builder.port(httpsPort);
}
return builder.scheme("https").toUriString();
}
}

View File

@ -0,0 +1,169 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.web.transport;
import jakarta.servlet.FilterChain;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.HttpHeaders;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.web.PortMapper;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link HttpsRedirectFilter}
*
* @author Josh Cummings
*/
@ExtendWith(MockitoExtension.class)
public class HttpsRedirectFilterTests {
HttpsRedirectFilter filter;
@Mock
FilterChain chain;
@BeforeEach
public void configureFilter() {
this.filter = new HttpsRedirectFilter();
}
@Test
public void filterWhenRequestIsInsecureThenRedirects() throws Exception {
HttpServletRequest request = get("http://localhost");
HttpServletResponse response = ok();
this.filter.doFilter(request, response, this.chain);
assertThat(statusCode(response)).isEqualTo(302);
assertThat(redirectedUrl(response)).isEqualTo("https://localhost");
}
@Test
public void filterWhenExchangeIsSecureThenNoRedirect() throws Exception {
HttpServletRequest request = get("https://localhost");
HttpServletResponse response = ok();
this.filter.doFilter(request, response, this.chain);
assertThat(statusCode(response)).isEqualTo(200);
}
@Test
public void filterWhenExchangeMismatchesThenNoRedirect() throws Exception {
RequestMatcher matcher = mock(RequestMatcher.class);
this.filter.setRequestMatcher(matcher);
HttpServletRequest request = get("http://localhost:8080");
HttpServletResponse response = ok();
this.filter.doFilter(request, response, this.chain);
assertThat(statusCode(response)).isEqualTo(200);
}
@Test
public void filterWhenExchangeMatchesAndRequestIsInsecureThenRedirects() throws Exception {
RequestMatcher matcher = mock(RequestMatcher.class);
given(matcher.matches(any())).willReturn(true);
this.filter.setRequestMatcher(matcher);
HttpServletRequest request = get("http://localhost:8080");
HttpServletResponse response = ok();
this.filter.doFilter(request, response, this.chain);
assertThat(statusCode(response)).isEqualTo(302);
assertThat(redirectedUrl(response)).isEqualTo("https://localhost:8443");
verify(matcher).matches(any(HttpServletRequest.class));
}
@Test
public void filterWhenRequestIsInsecureThenPortMapperRemapsPort() throws Exception {
PortMapper portMapper = mock(PortMapper.class);
given(portMapper.lookupHttpsPort(314)).willReturn(159);
this.filter.setPortMapper(portMapper);
HttpServletRequest request = get("http://localhost:314");
HttpServletResponse response = ok();
this.filter.doFilter(request, response, this.chain);
assertThat(statusCode(response)).isEqualTo(302);
assertThat(redirectedUrl(response)).isEqualTo("https://localhost:159");
verify(portMapper).lookupHttpsPort(314);
}
@Test
public void filterWhenRequestIsInsecureAndNoPortMappingThenThrowsIllegalState() {
HttpServletRequest request = get("http://localhost:1234");
HttpServletResponse response = ok();
assertThatIllegalStateException().isThrownBy(() -> this.filter.doFilter(request, response, this.chain));
}
@Test
public void filterWhenInsecureRequestHasAPathThenRedirects() throws Exception {
HttpServletRequest request = get("http://localhost:8080/path/page.html?query=string");
HttpServletResponse response = ok();
this.filter.doFilter(request, response, this.chain);
assertThat(statusCode(response)).isEqualTo(302);
assertThat(redirectedUrl(response)).isEqualTo("https://localhost:8443/path/page.html?query=string");
}
@Test
public void setRequiresTransportSecurityMatcherWhenSetWithNullValueThenThrowsIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null));
}
@Test
public void setPortMapperWhenSetWithNullValueThenThrowsIllegalArgument() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setPortMapper(null));
}
private String redirectedUrl(HttpServletResponse response) {
return response.getHeader(HttpHeaders.LOCATION);
}
private int statusCode(HttpServletResponse response) {
return response.getStatus();
}
private HttpServletRequest get(String uri) {
UriComponents components = UriComponentsBuilder.fromUriString(uri).build();
MockHttpServletRequest request = new MockHttpServletRequest("GET", components.getPath());
request.setQueryString(components.getQuery());
if (components.getScheme() != null) {
request.setScheme(components.getScheme());
}
int port = components.getPort();
if (port != -1) {
request.setServerPort(port);
}
return request;
}
private HttpServletResponse ok() {
MockHttpServletResponse response = new MockHttpServletResponse();
response.setStatus(200);
return response;
}
}