NIFI-14209 Restructured Host Header Validation (#9680)

- Replaced HostHeaderHandler with HostPortValidatorCustomizer
- Jetty SecureRequestCustomizer enforces host validation for SNI with Server Certificate DNS Subject Alternative Names
- Added tests for TLS SNI with invalid host and port values
- Refactored and streamlined RequestUriBuilder.fromHttpServletRequest()
This commit is contained in:
David Handermann 2025-02-05 13:43:07 -06:00 committed by GitHub
parent ea29da1cbf
commit ae5a77b84f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 316 additions and 514 deletions

View File

@ -59,17 +59,6 @@ public class RequestUriBuilder {
*/
public static RequestUriBuilder fromHttpServletRequest(final HttpServletRequest httpServletRequest) {
final List<String> allowedContextPaths = getAllowedContextPathsConfigured(httpServletRequest);
return fromHttpServletRequest(httpServletRequest, allowedContextPaths);
}
/**
* Return Builder from HTTP Servlet Request using Scheme, Host, Port, and Context Path reading from headers
*
* @param httpServletRequest HTTP Servlet Request
* @param allowedContextPaths Comma-separated string of allowed context path values for proxy headers
* @return Request URI Builder
*/
public static RequestUriBuilder fromHttpServletRequest(final HttpServletRequest httpServletRequest, final List<String> allowedContextPaths) {
final RequestUriProvider requestUriProvider = new StandardRequestUriProvider(allowedContextPaths);
final URI requestUri = requestUriProvider.getRequestUri(httpServletRequest);
return new RequestUriBuilder(requestUri.getScheme(), requestUri.getHost(), requestUri.getPort(), requestUri.getPath());

View File

@ -16,6 +16,7 @@
*/
package org.apache.nifi.web.servlet.shared;
import jakarta.servlet.ServletContext;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
@ -23,8 +24,6 @@ import org.mockito.junit.jupiter.MockitoExtension;
import jakarta.servlet.http.HttpServletRequest;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -48,16 +47,22 @@ public class RequestUriBuilderTest {
private static final String EMPTY = "";
private static final String ALLOWED_CONTEXT_PATHS = "allowedContextPaths";
@Mock
private HttpServletRequest httpServletRequest;
@Mock
private ServletContext servletContext;
@Test
public void testFromHttpServletRequestBuild() {
when(httpServletRequest.getServletContext()).thenReturn(servletContext);
when(httpServletRequest.getServerPort()).thenReturn(PORT);
when(httpServletRequest.getScheme()).thenReturn(SCHEME);
lenient().when(httpServletRequest.getHeader(eq(HOST_HEADER))).thenReturn(HOST);
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest, Collections.emptyList());
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest);
final URI uri = builder.build();
assertNotNull(uri);
@ -69,11 +74,12 @@ public class RequestUriBuilderTest {
@Test
public void testFromHttpServletRequestPathBuild() {
when(httpServletRequest.getServletContext()).thenReturn(servletContext);
when(httpServletRequest.getServerPort()).thenReturn(PORT);
when(httpServletRequest.getScheme()).thenReturn(SCHEME);
lenient().when(httpServletRequest.getHeader(eq(HOST_HEADER))).thenReturn(HOST);
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest, Collections.emptyList());
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest);
builder.fragment(FRAGMENT).path(CONTEXT_PATH);
final URI uri = builder.build();
@ -87,12 +93,14 @@ public class RequestUriBuilderTest {
@Test
public void testFromHttpServletRequestProxyHeadersBuild() {
when(httpServletRequest.getServletContext()).thenReturn(servletContext);
when(servletContext.getInitParameter(eq(ALLOWED_CONTEXT_PATHS))).thenReturn(CONTEXT_PATH);
when(httpServletRequest.getHeader(eq(ProxyHeader.PROXY_SCHEME.getHeader()))).thenReturn(SCHEME);
when(httpServletRequest.getHeader(eq(ProxyHeader.PROXY_HOST.getHeader()))).thenReturn(HOST);
when(httpServletRequest.getHeader(eq(ProxyHeader.PROXY_PORT.getHeader()))).thenReturn(Integer.toString(PORT));
when(httpServletRequest.getHeader(eq(ProxyHeader.PROXY_CONTEXT_PATH.getHeader()))).thenReturn(CONTEXT_PATH);
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest, List.of(CONTEXT_PATH));
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest);
final URI uri = builder.build();
assertNotNull(uri);

View File

@ -3444,6 +3444,7 @@ The value can be set to `h2` to require HTTP/2 and disable HTTP/1.1.
|`nifi.web.proxy.host`|A comma separated list of allowed HTTP Host header values to consider when NiFi is running securely and will be receiving requests to a different host[:port] than it is bound to.
For example, when running in a Docker container or behind a proxy (e.g. localhost:18443, proxyhost:443). By default, this value is blank meaning NiFi should only allow requests sent to the
host[:port] that NiFi is bound to.
Requests containing an invalid port in the Host or authority header return an HTTP 421 Misdirected Request status.
|`nifi.web.proxy.context.path`|A comma separated list of allowed HTTP X-ProxyContextPath, X-Forwarded-Context, or X-Forwarded-Prefix header values to consider. By default, this value is
blank meaning all requests containing a proxy context path are rejected. Configuring this property would allow requests where the proxy path is contained in this listing.
|`nifi.web.max.content.size`|The maximum size (HTTP `Content-Length`) for PUT and POST requests. No default value is set for backward compatibility. Providing a value for this property enables the `Content-Length` filter on all incoming API requests (except Site-to-Site and cluster communications). A suggested value is `20 MB`.

View File

@ -1,318 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.server;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.http.conn.util.InetAddressUtils;
import org.apache.nifi.util.NiFiProperties;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.util.Callback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.PrintWriter;
import java.net.HttpURLConnection;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
public class HostHeaderHandler extends Handler.Abstract {
private static final Logger logger = LoggerFactory.getLogger(HostHeaderHandler.class);
private final String serverName;
private final int serverPort;
private final List<String> validHosts;
/**
* Instantiates a handler which accepts incoming requests with a host header that is empty or contains one of the
* valid hosts. See the Apache NiFi Admin Guide for instructions on how to set valid hostnames and IP addresses.
*
* @param niFiProperties the NiFiProperties
*/
public HostHeaderHandler(final NiFiProperties niFiProperties) {
this.serverName = Objects.requireNonNull(determineServerHostname(niFiProperties));
this.serverPort = determineServerPort(niFiProperties);
// Default values across generic instances
List<String> hosts = generateDefaultHostnames(niFiProperties);
// The value from nifi.web.http|https.host
hosts.add(serverName.toLowerCase());
hosts.add(serverName.toLowerCase() + ":" + serverPort);
// The value(s) from nifi.web.proxy.host
hosts.addAll(parseCustomHostnames(niFiProperties));
// empty is ok here
hosts.add("");
this.validHosts = uniqueList(hosts);
logger.info("{} valid values for HTTP Request Host Header: {}", validHosts.size(), StringUtils.join(validHosts, ", "));
}
/**
* Returns the list of parsed custom hostnames from {@code nifi.web.proxy.host} in {@link NiFiProperties}.
* This list is deduplicated (if a host {@code somehost.com:1234} is provided, it will show twice, as the "portless"
* version {@code somehost.com} is also generated). IPv6 addresses are only modified if they adhere to the strict
* formatting using {@code []} around the address as specified in RFC 5952 Section 6 (i.e.
* {@code [1234.5678.90AB.CDEF.1234.5678.90AB.CDEF]:1234} will insert
* {@code [1234.5678.90AB.CDEF.1234.5678.90AB.CDEF]} as well).
*
* @param niFiProperties the properties object
* @return the list of parsed custom hostnames
*/
List<String> parseCustomHostnames(NiFiProperties niFiProperties) {
// Load the custom hostnames from the properties
List<String> customHostnames = niFiProperties.getAllowedHostsAsList();
/* Each IPv4 address and hostname may have the port associated, so duplicate the list and trim the port
* (the port may be different from the port NiFi is running on if provided by a proxy, etc.) IPv6 addresses
* are not modified.
*/
List<String> portlessHostnames = customHostnames.stream().map(hostname -> {
if (isIPv6Address(hostname)) {
return hostname;
} else {
return StringUtils.substringBeforeLast(hostname, ":");
}
}
).collect(Collectors.toList());
customHostnames.addAll(portlessHostnames);
if (logger.isDebugEnabled()) {
logger.debug("Parsed {} custom hostnames from nifi.web.proxy.host: {}", customHostnames.size(), StringUtils.join(customHostnames, ", "));
}
return uniqueList(customHostnames);
}
/**
* Returns a unique {@code List} of the elements maintaining the original order.
*
* @param duplicateList a list that may contain duplicate elements
* @return a list maintaining the original order which no longer contains duplicate elements
*/
private static List<String> uniqueList(List<String> duplicateList) {
return new ArrayList<>(new LinkedHashSet<>(duplicateList));
}
/**
* Returns true if the provided address is an IPv6 address (or could be interpreted as one). This method is more
* lenient than {@link InetAddressUtils#isIPv6Address(String)} because of different interpretations of IPv4-mapped
* IPv6 addresses.
* See RFC 5952 Section 4 for more information on textual representation of the IPv6 addresses.
*
* @param address the address in text form
* @return true if the address is or could be parsed as an IPv6 address
*/
static boolean isIPv6Address(String address) {
// Note: InetAddressUtils#isIPv4MappedIPv64Address() fails on addresses that do not compress the leading 0:0:0... to ::
// Expanded for debugging purposes
boolean isNormalIPv6 = InetAddressUtils.isIPv6Address(address);
// If the last two hextets are written in IPv4 form, treat it as an IPv6 address as well
String everythingAfterLastColon = StringUtils.substringAfterLast(address, ":");
boolean isIPv4 = InetAddressUtils.isIPv4Address(everythingAfterLastColon);
return isNormalIPv6 || isIPv4;
}
private int determineServerPort(NiFiProperties props) {
return props.getSslPort() != null ? props.getSslPort() : props.getPort();
}
private String determineServerHostname(NiFiProperties props) {
if (props.getSslPort() != null) {
return props.getProperty(NiFiProperties.WEB_HTTPS_HOST, "localhost");
} else {
return props.getProperty(NiFiProperties.WEB_HTTP_HOST, "localhost");
}
}
/**
* Host Header Valid status checks against valid hosts
*
* @param hostHeader Host header value
* @return Valid status
*/
boolean hostHeaderIsValid(final String hostHeader) {
return hostHeader != null && validHosts.contains(hostHeader.toLowerCase().trim());
}
@Override
public String toString() {
return "HostHeaderHandler for " + serverName + ":" + serverPort;
}
/**
* Returns an error message to the response and marks the request as handled if the host header is not valid.
* Otherwise passes the request on to the next scoped handler.
*
* @param request the request as an HttpServletRequest
* @param response the current response
*/
@Override
public boolean handle(final Request request, Response response, Callback callback) {
final String hostHeader = request.getHeaders().get(HttpHeader.HOST);
final String requestUri = request.getHttpURI().asString();
logger.debug("Request URI [{}] Host Header [{}]", requestUri, hostHeader);
if (!hostHeaderIsValid(hostHeader)) {
logger.warn("Request URI [{}] Host Header [{}] not valid", requestUri, hostHeader);
response.getHeaders().put(HttpHeader.CONTENT_TYPE, "text/html; charset=utf-8");
response.setStatus(HttpURLConnection.HTTP_OK);
try (PrintWriter out = Response.as(response, PrintWriter.class)) {
out.println("<h1>System Error</h1>");
out.println("<h2>The request contained an invalid host header [<code>" + StringEscapeUtils.escapeHtml4(hostHeader) +
"</code>] in the request [<code>" + StringEscapeUtils.escapeHtml4(request.getHttpURI().asString()) +
"</code>]. Check for request manipulation or third-party intercept.</h2>");
out.println("<h3>Valid host headers are [<code>empty</code>] or: <br/><code>");
out.println(printValidHosts());
out.println("</code></h3>");
}
return true;
} else {
return false;
}
}
String printValidHosts() {
StringBuilder sb = new StringBuilder("<ul>");
for (String vh : validHosts) {
if (StringUtils.isNotBlank(vh))
sb.append("<li>").append(StringEscapeUtils.escapeHtml4(vh)).append("</li>\n");
}
return sb.append("</ul>\n").toString();
}
public static List<String> generateDefaultHostnames(NiFiProperties niFiProperties) {
List<String> validHosts = new ArrayList<>();
int serverPort = 0;
if (niFiProperties == null) {
logger.warn("NiFiProperties not configured; returning minimal default hostnames");
} else {
try {
serverPort = niFiProperties.getConfiguredHttpOrHttpsPort();
} catch (RuntimeException e) {
logger.warn("Cannot fully generate list of default hostnames because the server port is not configured in nifi.properties. Defaulting to port 0 for host header evaluation");
}
// Add any custom network interfaces
try {
final int lambdaPort = serverPort;
List<String> customIPs = extractIPsFromNetworkInterfaces(niFiProperties);
customIPs.forEach(ip -> {
validHosts.add(ip);
validHosts.add(ip + ":" + lambdaPort);
});
} catch (final Exception e) {
logger.warn("Failed to determine custom network interfaces.", e);
}
}
// Sometimes the hostname is left empty but the port is always populated
validHosts.add("127.0.0.1");
validHosts.add("127.0.0.1:" + serverPort);
validHosts.add("localhost");
validHosts.add("localhost:" + serverPort);
validHosts.add("[::1]");
validHosts.add("[::1]:" + serverPort);
// Add the loopback and actual IP address and hostname used
try {
validHosts.add(InetAddress.getLoopbackAddress().getHostAddress().toLowerCase());
validHosts.add(InetAddress.getLoopbackAddress().getHostAddress().toLowerCase() + ":" + serverPort);
validHosts.add(InetAddress.getLocalHost().getHostName().toLowerCase());
validHosts.add(InetAddress.getLocalHost().getHostName().toLowerCase() + ":" + serverPort);
validHosts.add(InetAddress.getLocalHost().getHostAddress().toLowerCase());
validHosts.add(InetAddress.getLocalHost().getHostAddress().toLowerCase() + ":" + serverPort);
} catch (final Exception e) {
logger.warn("Failed to determine local hostname.", e);
}
// Dedupe but maintain order
final List<String> uniqueHosts = uniqueList(validHosts);
if (logger.isDebugEnabled()) {
logger.debug("Determined {} valid default hostnames and IP addresses for incoming headers: {}", uniqueHosts.size(), StringUtils.join(uniqueHosts, ", "));
}
return uniqueHosts;
}
/**
* Extracts the list of IP addresses from custom bound network interfaces. If both HTTPS and HTTP interfaces are
* defined and HTTPS is enabled, only HTTPS interfaces will be returned. If none are defined, an empty list will be
* returned.
*
* @param niFiProperties the NiFiProperties object
* @return the list of IP addresses
*/
static List<String> extractIPsFromNetworkInterfaces(NiFiProperties niFiProperties) {
Map<String, String> networkInterfaces = niFiProperties.isHTTPSConfigured() ? niFiProperties.getHttpsNetworkInterfaces() : niFiProperties.getHttpNetworkInterfaces();
if (isNotDefined(networkInterfaces)) {
// No custom interfaces defined
return List.of();
} else {
final List<String> allIPAddresses = new ArrayList<>();
for (Map.Entry<String, String> entry : networkInterfaces.entrySet()) {
final String networkInterfaceName = entry.getValue();
try {
final NetworkInterface ni = NetworkInterface.getByName(networkInterfaceName);
if (ni == null) {
logger.warn("Cannot resolve network interface named {}", networkInterfaceName);
} else {
final List<String> ipAddresses = Collections.list(ni.getInetAddresses()).stream().map(inetAddress -> inetAddress.getHostAddress().toLowerCase()).collect(Collectors.toList());
logger.debug("Resolved the following IP addresses for network interface {}: {}", networkInterfaceName, StringUtils.join(ipAddresses, ", "));
allIPAddresses.addAll(ipAddresses);
}
} catch (SocketException e) {
logger.warn("Cannot resolve network interface named {}", networkInterfaceName);
}
}
// Dedupe while maintaining order
return uniqueList(allIPAddresses);
}
}
/**
* Returns true if the provided map of properties and network interfaces is null, empty, or the actual definitions are empty.
*
* @param networkInterfaces the map of properties to bindings
* ({@code ["nifi.web.http.network.interface.first":"eth0"]})
* @return Not Defined status
*/
static boolean isNotDefined(Map<String, String> networkInterfaces) {
return networkInterfaces == null || networkInterfaces.isEmpty() || networkInterfaces.values().stream().filter(value -> StringUtils.isNotBlank(value)).collect(Collectors.toList()).isEmpty();
}
}

View File

@ -66,7 +66,7 @@ class StandardServerProvider implements ServerProvider {
final Server server = new Server(threadPool);
addConnectors(server, properties, sslContext);
final Handler standardHandler = getStandardHandler(properties);
final Handler standardHandler = getStandardHandler();
server.setHandler(standardHandler);
final RewriteHandler defaultRewriteHandler = new RewriteHandler();
@ -123,19 +123,13 @@ class StandardServerProvider implements ServerProvider {
}
}
private Handler getStandardHandler(final NiFiProperties properties) {
private Handler getStandardHandler() {
// Standard Handler supporting an ordered sequence of Handlers invoked until completion
final Handler.Collection standardHandler = new Handler.Sequence();
// Set Handler for standard response headers
standardHandler.addHandler(new HeaderWriterHandler());
// Validate Host Header when running with HTTPS enabled
if (properties.isHTTPSConfigured()) {
final HostHeaderHandler hostHeaderHandler = new HostHeaderHandler(properties);
standardHandler.addHandler(hostHeaderHandler);
}
return standardHandler;
}
}

View File

@ -22,7 +22,6 @@ import org.apache.nifi.jetty.configuration.connector.ApplicationLayerProtocol;
import org.apache.nifi.jetty.configuration.connector.StandardServerConnectorFactory;
import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.security.util.TlsPlatform;
import org.apache.nifi.util.FormatUtils;
import org.apache.nifi.util.NiFiProperties;
import org.eclipse.jetty.server.HostHeaderCustomizer;
import org.eclipse.jetty.server.HttpConfiguration;
@ -30,28 +29,33 @@ import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* Framework extension of Server Connector Factory configures additional settings based on application properties
*/
public class FrameworkServerConnectorFactory extends StandardServerConnectorFactory {
private static final String DEFAULT_AUTO_REFRESH_INTERVAL = "30 s";
private static final int IDLE_TIMEOUT_MULTIPLIER = 2;
private static final int IDLE_TIMEOUT = 60000;
private static final String CIPHER_SUITE_SEPARATOR_PATTERN = ",\\s*";
private final int headerSize;
private static final Pattern HOST_PORT_PATTERN = Pattern.compile(".+?:(\\d+)$");
private final int idleTimeout;
private static final int PORT_GROUP = 1;
private final int headerSize;
private final String includeCipherSuites;
private final String excludeCipherSuites;
private final Set<Integer> validPorts;
private SslContextFactory.Server sslContextFactory;
/**
@ -66,7 +70,7 @@ public class FrameworkServerConnectorFactory extends StandardServerConnectorFact
includeCipherSuites = properties.getProperty(NiFiProperties.WEB_HTTPS_CIPHERSUITES_INCLUDE);
excludeCipherSuites = properties.getProperty(NiFiProperties.WEB_HTTPS_CIPHERSUITES_EXCLUDE);
headerSize = DataUnit.parseDataSize(properties.getWebMaxHeaderSize(), DataUnit.B).intValue();
idleTimeout = getIdleTimeout();
validPorts = getValidPorts(properties);
if (properties.isHTTPSConfigured()) {
if (properties.isClientAuthRequiredForRestApi()) {
@ -93,11 +97,14 @@ public class FrameworkServerConnectorFactory extends StandardServerConnectorFact
httpConfiguration.setRequestHeaderSize(headerSize);
httpConfiguration.setResponseHeaderSize(headerSize);
httpConfiguration.setIdleTimeout(idleTimeout);
httpConfiguration.setIdleTimeout(IDLE_TIMEOUT);
// Add HostHeaderCustomizer to set Host Header for HTTP/2 and HostHeaderHandler
httpConfiguration.addCustomizer(new HostHeaderCustomizer());
final HostPortValidatorCustomizer hostPortValidatorCustomizer = new HostPortValidatorCustomizer(validPorts);
httpConfiguration.addCustomizer(hostPortValidatorCustomizer);
return httpConfiguration;
}
@ -135,12 +142,6 @@ public class FrameworkServerConnectorFactory extends StandardServerConnectorFact
setApplicationLayerProtocols(applicationLayerProtocols);
}
private int getIdleTimeout() {
final String autoRefreshInterval = DEFAULT_AUTO_REFRESH_INTERVAL;
final double autoRefreshMilliseconds = FormatUtils.getPreciseTimeDuration(autoRefreshInterval, TimeUnit.MILLISECONDS);
return Math.multiplyExact((int) autoRefreshMilliseconds, IDLE_TIMEOUT_MULTIPLIER);
}
private String[] getCipherSuites(final String cipherSuitesProperty) {
return cipherSuitesProperty.split(CIPHER_SUITE_SEPARATOR_PATTERN);
}
@ -157,4 +158,23 @@ public class FrameworkServerConnectorFactory extends StandardServerConnectorFact
return ObjectUtils.defaultIfNull(httpsPort, httpPort);
}
private static Set<Integer> getValidPorts(final NiFiProperties properties) {
final Set<Integer> validPorts = new HashSet<>();
final int serverPort = getPort(properties);
validPorts.add(serverPort);
final List<String> allowedHosts = properties.getAllowedHostsAsList();
for (final String allowedHost : allowedHosts) {
final Matcher portMatcher = HOST_PORT_PATTERN.matcher(allowedHost);
if (portMatcher.matches()) {
final String portGroup = portMatcher.group(PORT_GROUP);
final int allowedPort = Integer.parseInt(portGroup);
validPorts.add(allowedPort);
}
}
return validPorts;
}
}

View File

@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.server.connector;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.server.ConnectionMetaData;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.Request;
import java.net.InetSocketAddress;
import java.util.Objects;
import java.util.Set;
/**
* Jetty Request Customizer implementing validation of port included in HTTP/1.1 Host Header or HTTP/2 authority header
*/
public class HostPortValidatorCustomizer implements HttpConfiguration.Customizer {
private static final String MISDIRECTED_REQUEST_REASON = "Invalid Port Requested";
private static final int PORT_NOT_SPECIFIED = -1;
private final Set<Integer> validPorts;
/**
* HOst Port Validator Customer constructor with additional valid ports from application properties
*
* @param validPorts Valid Ports on HTTPS requests
*/
public HostPortValidatorCustomizer(final Set<Integer> validPorts) {
this.validPorts = Objects.requireNonNull(validPorts, "Valid Ports required");
}
/**
* Validate requested port against connected port and valid ports for secure HTTPS requests.
* The port is not specified when the header includes only the domain name as described in RFC 9110 Section 7.2.
* The port must match the local socket address port or a configured valid port number.
*
* @param request HTTP Request to be evaluated
* @param responseHeaders HTTP Response headers
* @return Valid HTTP Request
*/
@Override
public Request customize(final Request request, final HttpFields.Mutable responseHeaders) {
final Request customized;
if (request.isSecure()) {
final HttpURI requestUri = request.getHttpURI();
final int port = requestUri.getPort();
final int localSocketAddressPort = getLocalSocketAddressPort(request);
if (PORT_NOT_SPECIFIED == port || localSocketAddressPort == port || validPorts.contains(port)) {
customized = request;
} else {
throw new BadMessageException(HttpStatus.MISDIRECTED_REQUEST_421, MISDIRECTED_REQUEST_REASON);
}
} else {
customized = request;
}
return customized;
}
private int getLocalSocketAddressPort(final Request request) {
final ConnectionMetaData connectionMetaData = request.getConnectionMetaData();
final InetSocketAddress localSocketAddress = (InetSocketAddress) connectionMetaData.getLocalSocketAddress();
return localSocketAddress.getPort();
}
}

View File

@ -1,139 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.server;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.util.NiFiProperties;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class HostHeaderHandlerTest {
private static final String DEFAULT_HOSTNAME = "nifi.apache.org";
private static final int DEFAULT_PORT = 8080;
private static final List<String> IPV6_HOSTS = Arrays.asList("ABCD:EF01:2345:6789:ABCD:EF01:2345:6789",
"2001:DB8:0:0:8:800:200C:417A",
"FF01:0:0:0:0:0:0:101",
"0:0:0:0:0:0:0:1",
"0:0:0:0:0:0:0:0",
"2001:DB8::8:800:200C:417A",
"FF01::101",
"::1",
"::",
"0:0:0:0:0:0:13.1.68.3",
"0:0:0:0:0:FFFF:129.144.52.38",
"::13.1.68.3",
"FFFF:129.144.52.38",
"::FFFF:129.144.52.38");
private static List<String> defaultHostsAndPorts;
@BeforeAll
public static void setUpOnce() throws Exception {
String actualHostname = InetAddress.getLocalHost().getHostName().toLowerCase();
List<String> defaultHosts150 = Arrays.asList(DEFAULT_HOSTNAME, "localhost", actualHostname);
String actualIp = InetAddress.getLocalHost().getHostAddress();
String loopbackIp = InetAddress.getLoopbackAddress().getHostAddress();
List<String> defaultHosts = new ArrayList<>(defaultHosts150);
defaultHosts.remove(DEFAULT_HOSTNAME);
defaultHosts.addAll(Arrays.asList("[::1]", "127.0.0.1", actualIp, loopbackIp));
defaultHostsAndPorts = buildHostsWithPorts(defaultHosts, DEFAULT_PORT);
}
@Test
public void testNewConstructorShouldHandleCurrentDefaultValues() {
HostHeaderHandler handler = new HostHeaderHandler(getNifiProperties(null));
defaultHostsAndPorts.forEach(host -> assertTrue(handler.hostHeaderIsValid(host)));
}
@Test
public void testShouldParseCustomHostnames() {
List<String> otherHosts = Arrays.asList("someotherhost.com:9999", "yetanotherbadhost.com", "10.10.10.1:1234", "100.100.100.1");
NiFiProperties nifiProperties = getNifiProperties(otherHosts);
HostHeaderHandler handler = new HostHeaderHandler(nifiProperties);
final List<String> customHostnames = handler.parseCustomHostnames(nifiProperties);
assertEquals(otherHosts.size() + 2, customHostnames.size()); // Two provided hostnames had ports
otherHosts.forEach(host -> {
assertTrue(customHostnames.contains(host));
String portlessHost = host.split(":", 2)[0];
assertTrue(customHostnames.contains(portlessHost));
});
}
@Test
public void testParseCustomHostnamesShouldHandleIPv6WithoutPorts() {
NiFiProperties nifiProperties = getNifiProperties(IPV6_HOSTS);
HostHeaderHandler handler = new HostHeaderHandler(nifiProperties);
List<String> customHostnames = handler.parseCustomHostnames(nifiProperties);
assertEquals(IPV6_HOSTS.size(), customHostnames.size());
IPV6_HOSTS.forEach(host -> assertTrue(customHostnames.contains(host)));
}
@Test
public void testParseCustomHostnamesShouldHandleIPv6WithPorts() {
int port = 1234;
List<String> ipv6HostsWithPorts = buildHostsWithPorts(IPV6_HOSTS.stream()
.map(host -> "[" + host + "]")
.collect(Collectors.toList()), port);
NiFiProperties nifiProperties = getNifiProperties(ipv6HostsWithPorts);
HostHeaderHandler handler = new HostHeaderHandler(nifiProperties);
List<String> customHostnames = handler.parseCustomHostnames(nifiProperties);
assertEquals(ipv6HostsWithPorts.size() * 2, customHostnames.size());
ipv6HostsWithPorts.forEach(host -> {
assertTrue(customHostnames.contains(host));
String portlessHost = StringUtils.substringBeforeLast(host, ":");
assertTrue(customHostnames.contains(portlessHost));
}
);
}
@Test
public void testShouldIdentifyIPv6Addresses() {
IPV6_HOSTS.forEach(host -> assertTrue(HostHeaderHandler.isIPv6Address(host)));
}
private static List<String> buildHostsWithPorts(List<String> hosts, int port) {
return hosts.stream()
.map(host -> host + ":" + port)
.collect(Collectors.toList());
}
private NiFiProperties getNifiProperties(List<String> hosts) {
Properties bareboneProperties = new Properties();
bareboneProperties.put(NiFiProperties.WEB_HTTPS_HOST, DEFAULT_HOSTNAME);
bareboneProperties.put(NiFiProperties.WEB_HTTPS_PORT, Integer.toString(DEFAULT_PORT));
if (hosts != null) {
bareboneProperties.put(NiFiProperties.WEB_PROXY_HOST, String.join(",", hosts));
}
return new NiFiProperties(bareboneProperties);
}
}

View File

@ -17,24 +17,44 @@
package org.apache.nifi.web.server;
import org.apache.nifi.jetty.configuration.connector.ApplicationLayerProtocol;
import org.apache.nifi.security.cert.builder.StandardCertificateBuilder;
import org.apache.nifi.security.ssl.EphemeralKeyStoreBuilder;
import org.apache.nifi.security.ssl.StandardSslContextBuilder;
import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.web.server.handler.HeaderWriterHandler;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.rewrite.handler.RewriteHandler;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.RequestLog;
import org.eclipse.jetty.server.Server;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.springframework.web.util.UriComponentsBuilder;
import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import javax.security.auth.x500.X500Principal;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
class StandardServerProviderTest {
@ -42,6 +62,48 @@ class StandardServerProviderTest {
private static final String SSL_PROTOCOL = "ssl";
private static final Duration TIMEOUT = Duration.ofSeconds(15);
private static final String ALIAS = "entry-0";
private static final char[] PROTECTION_PARAMETER = new char[]{};
private static final String LOCALHOST_NAME = "localhost";
private static final X500Principal LOCALHOST_SUBJECT = new X500Principal("CN=%s, O=NiFi".formatted(LOCALHOST_NAME));
private static final String LOCALHOST_ADDRESS = "127.0.0.1";
private static final String LOCALHOST_HTTP_PORT = "localhost:80";
private static final String HOST_HEADER = "Host";
private static final String PUBLIC_HOST = "nifi.apache.org";
private static final String PUBLIC_UNKNOWN_HOST = "nifi.staged.apache.org";
private static final String ALLOW_RESTRICTED_HEADERS_PROPERTY = "jdk.httpclient.allowRestrictedHeaders";
private static SSLContext sslContext;
@BeforeAll
static void setConfiguration() throws Exception {
final KeyPair keyPair = KeyPairGenerator.getInstance("RSA").generateKeyPair();
final X509Certificate certificate = new StandardCertificateBuilder(keyPair, LOCALHOST_SUBJECT, Duration.ofHours(1))
.setDnsSubjectAlternativeNames(List.of(PUBLIC_HOST))
.build();
final KeyStore keyStore = new EphemeralKeyStoreBuilder().build();
keyStore.setKeyEntry(ALIAS, keyPair.getPrivate(), PROTECTION_PARAMETER, new Certificate[]{certificate});
sslContext = new StandardSslContextBuilder()
.keyStore(keyStore)
.trustStore(keyStore)
.keyPassword(PROTECTION_PARAMETER)
.build();
// Allow Restricted Headers for testing TLS SNI
System.setProperty(ALLOW_RESTRICTED_HEADERS_PROPERTY, HOST_HEADER);
}
@Test
void testGetServer() {
final Properties applicationProperties = new Properties();
@ -57,12 +119,11 @@ class StandardServerProviderTest {
}
@Test
void testGetServerHttps() throws NoSuchAlgorithmException {
void testGetServerHttps() {
final Properties applicationProperties = new Properties();
applicationProperties.setProperty(NiFiProperties.WEB_HTTPS_PORT, RANDOM_PORT);
final NiFiProperties properties = NiFiProperties.createBasicNiFiProperties(null, applicationProperties);
final SSLContext sslContext = SSLContext.getDefault();
final StandardServerProvider provider = new StandardServerProvider(sslContext);
final Server server = provider.getServer(properties);
@ -93,6 +154,93 @@ class StandardServerProviderTest {
}
}
@Timeout(15)
@Test
void testGetServerHttpsRequestsCompleted() throws Exception {
final Properties applicationProperties = new Properties();
applicationProperties.setProperty(NiFiProperties.WEB_HTTPS_PORT, RANDOM_PORT);
applicationProperties.setProperty(NiFiProperties.WEB_PROXY_HOST, PUBLIC_HOST);
final NiFiProperties properties = NiFiProperties.createBasicNiFiProperties(null, applicationProperties);
final StandardServerProvider provider = new StandardServerProvider(sslContext);
final Server server = provider.getServer(properties);
assertStandardConfigurationFound(server);
assertHttpsConnectorFound(server);
try {
server.start();
assertFalse(server.isFailed());
while (server.isStarting()) {
TimeUnit.MILLISECONDS.sleep(250);
}
assertTrue(server.isStarted());
final URI uri = server.getURI();
assertHttpsRequestsCompleted(uri);
} finally {
server.stop();
}
}
void assertHttpsRequestsCompleted(final URI serverUri) throws IOException, InterruptedException {
try (HttpClient httpClient = HttpClient.newBuilder()
.connectTimeout(TIMEOUT)
.sslContext(sslContext)
.build()
) {
final URI localhostUri = UriComponentsBuilder.fromUri(serverUri).host(LOCALHOST_NAME).build().toUri();
assertRedirectRequestsCompleted(httpClient, localhostUri);
assertBadRequestsCompleted(httpClient, localhostUri);
assertMisdirectedRequestsCompleted(httpClient, localhostUri);
}
}
void assertRedirectRequestsCompleted(final HttpClient httpClient, final URI localhostUri) throws IOException, InterruptedException {
final HttpRequest localhostRequest = HttpRequest.newBuilder(localhostUri)
.version(HttpClient.Version.HTTP_2)
.build();
assertResponseStatusCode(httpClient, localhostRequest, HttpStatus.MOVED_TEMPORARILY_302);
final HttpRequest alternativeNameRequest = HttpRequest.newBuilder(localhostUri)
.version(HttpClient.Version.HTTP_1_1)
.header(HOST_HEADER, PUBLIC_HOST)
.build();
assertResponseStatusCode(httpClient, alternativeNameRequest, HttpStatus.MOVED_TEMPORARILY_302);
}
void assertBadRequestsCompleted(final HttpClient httpClient, final URI localhostUri) throws IOException, InterruptedException {
final HttpRequest publicHostHeaderRequest = HttpRequest.newBuilder(localhostUri)
.header(HOST_HEADER, PUBLIC_UNKNOWN_HOST)
.version(HttpClient.Version.HTTP_1_1)
.build();
assertResponseStatusCode(httpClient, publicHostHeaderRequest, HttpStatus.BAD_REQUEST_400);
final HttpRequest localhostAddressRequest = HttpRequest.newBuilder(localhostUri)
.header(HOST_HEADER, LOCALHOST_ADDRESS)
.version(HttpClient.Version.HTTP_1_1)
.build();
assertResponseStatusCode(httpClient, localhostAddressRequest, HttpStatus.BAD_REQUEST_400);
}
void assertMisdirectedRequestsCompleted(final HttpClient httpClient, final URI localhostUri) throws IOException, InterruptedException {
final HttpRequest localhostPortRequest = HttpRequest.newBuilder(localhostUri)
.version(HttpClient.Version.HTTP_1_1)
.header(HOST_HEADER, LOCALHOST_HTTP_PORT)
.build();
assertResponseStatusCode(httpClient, localhostPortRequest, HttpStatus.MISDIRECTED_REQUEST_421);
}
void assertResponseStatusCode(final HttpClient httpClient, final HttpRequest request, final int statusCodeExpected) throws IOException, InterruptedException {
final HttpResponse<Void> response = httpClient.send(request, HttpResponse.BodyHandlers.discarding());
assertEquals(statusCodeExpected, response.statusCode());
}
void assertHttpConnectorFound(final Server server) {
final Connector[] connectors = server.getConnectors();
assertNotNull(connectors);

View File

@ -172,7 +172,7 @@ public abstract class ApplicationResource {
* @return the full external UI
*/
protected String generateExternalUiUri(final String... pathSegments) {
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest, properties.getAllowedContextPathsAsList());
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest);
final String path = String.join("/", pathSegments);
builder.path(path);
@ -186,7 +186,7 @@ public abstract class ApplicationResource {
}
private URI buildResourceUri(final URI uri) {
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest, properties.getAllowedContextPathsAsList());
final RequestUriBuilder builder = RequestUriBuilder.fromHttpServletRequest(httpServletRequest);
builder.path(uri.getPath());
return builder.build();
}

View File

@ -16,6 +16,7 @@
*/
package org.apache.nifi.web.api;
import jakarta.servlet.ServletContext;
import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.web.servlet.shared.ProxyHeader;
import org.glassfish.jersey.uri.internal.JerseyUriBuilder;
@ -38,6 +39,7 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.when;
@ -54,10 +56,14 @@ public class TestApplicationResource {
private static final String ACTUAL_RESOURCE = "actualResource";
private static final String EXPECTED_URI = BASE_URI + ":" + PORT + ALLOWED_PATH + FORWARD_SLASH + ACTUAL_RESOURCE;
private static final String MULTIPLE_ALLOWED_PATHS = String.join(",", ALLOWED_PATH, "another/path", "a/third/path");
private static final String ALLOWED_CONTEXT_PATHS = "allowedContextPaths";
@Mock
private HttpServletRequest request;
@Mock
private ServletContext servletContext;
private MockApplicationResource resource;
@BeforeEach
@ -69,6 +75,8 @@ public class TestApplicationResource {
when(request.getServerName()).thenReturn(HOST);
when(request.getServerPort()).thenReturn(PORT);
when(request.getServletContext()).thenReturn(servletContext);
resource = new MockApplicationResource();
resource.setHttpServletRequest(request);
resource.setUriInfo(uriInfo);
@ -156,6 +164,7 @@ public class TestApplicationResource {
private void setNiFiProperties(Map<String, String> props) {
resource.properties = new NiFiProperties(props);
when(servletContext.getInitParameter(eq(ALLOWED_CONTEXT_PATHS))).thenReturn(resource.properties.getAllowedContextPaths());
}
private static class MockApplicationResource extends ApplicationResource {

View File

@ -49,6 +49,7 @@ import java.net.URL;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
@ -62,6 +63,7 @@ public class TestDataTransferResource {
@BeforeAll
public static void setup() throws Exception {
final URL resource = TestDataTransferResource.class.getResource("/site-to-site/nifi.properties");
assertNotNull(resource);
final String propertiesFile = resource.toURI().getPath();
System.setProperty(NiFiProperties.PROPERTIES_FILE_PATH, propertiesFile);
}
@ -72,6 +74,8 @@ public class TestDataTransferResource {
doReturn(new StringBuffer("http://nifi.example.com:8080")
.append("/nifi-api/data-transfer/output-ports/port-id/transactions/tx-id/flow-files"))
.when(req).getRequestURL();
final ServletContext servletContext = mock(ServletContext.class);
when(req.getServletContext()).thenReturn(servletContext);
return req;
}
@ -174,6 +178,8 @@ public class TestDataTransferResource {
.getDeclaredField("httpServletRequest");
httpServletRequestField.setAccessible(true);
httpServletRequestField.set(resource, request);
final ServletContext servletContext = mock(ServletContext.class);
when(request.getServletContext()).thenReturn(servletContext);
final InputStream inputStream = null;
@ -209,6 +215,8 @@ public class TestDataTransferResource {
.getDeclaredField("httpServletRequest");
httpServletRequestField.setAccessible(true);
httpServletRequestField.set(resource, request);
final ServletContext servletContext = mock(ServletContext.class);
when(request.getServletContext()).thenReturn(servletContext);
final InputStream inputStream = null;

View File

@ -303,7 +303,7 @@ public class SamlAuthenticationSecurityConfiguration {
*/
@Bean
public RelyingPartyRegistrationResolver relyingPartyRegistrationResolver() {
return new StandardRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository(), properties.getAllowedContextPathsAsList());
return new StandardRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository());
}
/**

View File

@ -29,7 +29,6 @@ import org.springframework.web.util.UriComponentsBuilder;
import jakarta.servlet.http.HttpServletRequest;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -45,17 +44,13 @@ public class StandardRelyingPartyRegistrationResolver implements Converter<HttpS
private final RelyingPartyRegistrationRepository repository;
private final List<String> allowedContextPaths;
/**
* Standard Resolver with Registration Repository and Allowed Context Paths from application properties
*
* @param repository Relying Party Registration Repository required
* @param allowedContextPaths Allowed Context Paths required
*/
public StandardRelyingPartyRegistrationResolver(final RelyingPartyRegistrationRepository repository, final List<String> allowedContextPaths) {
public StandardRelyingPartyRegistrationResolver(final RelyingPartyRegistrationRepository repository) {
this.repository = Objects.requireNonNull(repository, "Repository required");
this.allowedContextPaths = Objects.requireNonNull(allowedContextPaths, "Allowed Context Paths required");
}
/**
@ -116,7 +111,7 @@ public class StandardRelyingPartyRegistrationResolver implements Converter<HttpS
}
private String getBaseUrl(final HttpServletRequest request) {
final URI requestUri = RequestUriBuilder.fromHttpServletRequest(request, allowedContextPaths).build();
final URI requestUri = RequestUriBuilder.fromHttpServletRequest(request).build();
final String httpUrl = requestUri.toString();
final String contextPath = request.getContextPath();
return UriComponentsBuilder.fromUriString(httpUrl).path(contextPath).replaceQuery(null).fragment(null).build().toString();

View File

@ -27,8 +27,6 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
@ -57,6 +55,8 @@ class StandardRelyingPartyRegistrationResolverTest {
private static final String REGISTRATION_ID = Saml2RegistrationProperty.REGISTRATION_ID.getProperty();
private static final String ALLOWED_CONTEXT_PATHS = "allowedContextPaths";
@Mock
RelyingPartyRegistrationRepository repository;
@ -73,7 +73,7 @@ class StandardRelyingPartyRegistrationResolverTest {
@Test
void testResolveNotFound() {
final StandardRelyingPartyRegistrationResolver resolver = new StandardRelyingPartyRegistrationResolver(repository, Collections.emptyList());
final StandardRelyingPartyRegistrationResolver resolver = new StandardRelyingPartyRegistrationResolver(repository);
final RelyingPartyRegistration registration = resolver.resolve(request, REGISTRATION_ID);
@ -82,7 +82,7 @@ class StandardRelyingPartyRegistrationResolverTest {
@Test
void testResolveFound() {
final StandardRelyingPartyRegistrationResolver resolver = new StandardRelyingPartyRegistrationResolver(repository, Collections.emptyList());
final StandardRelyingPartyRegistrationResolver resolver = new StandardRelyingPartyRegistrationResolver(repository);
final RelyingPartyRegistration registration = getRegistrationBuilder().build();
when(repository.findByRegistrationId(eq(REGISTRATION_ID))).thenReturn(registration);
@ -95,7 +95,9 @@ class StandardRelyingPartyRegistrationResolverTest {
@Test
void testResolveSingleLogoutForwardedPathFound() {
final StandardRelyingPartyRegistrationResolver resolver = new StandardRelyingPartyRegistrationResolver(repository, Collections.singletonList(FORWARDED_PATH));
request.getServletContext().setInitParameter(ALLOWED_CONTEXT_PATHS, FORWARDED_PATH);
final StandardRelyingPartyRegistrationResolver resolver = new StandardRelyingPartyRegistrationResolver(repository);
final RelyingPartyRegistration registration = getSingleLogoutRegistration();
when(repository.findByRegistrationId(eq(REGISTRATION_ID))).thenReturn(registration);