diff --git a/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java b/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java index 578206aadb..5e508e30d6 100644 --- a/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java +++ b/nifi-commons/nifi-web-utils/src/main/java/org/apache/nifi/web/util/WebUtils.java @@ -16,18 +16,6 @@ */ package org.apache.nifi.web.util; -import java.net.URI; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.stream.Stream; -import javax.net.ssl.SSLContext; -import javax.servlet.ServletRequest; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.client.Client; -import javax.ws.rs.client.ClientBuilder; -import javax.ws.rs.core.UriBuilderException; import org.apache.commons.lang3.StringUtils; import org.apache.http.conn.ssl.DefaultHostnameVerifier; import org.glassfish.jersey.client.ClientConfig; @@ -35,6 +23,19 @@ import org.glassfish.jersey.jackson.internal.jackson.jaxrs.json.JacksonJaxbJsonP import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.ssl.SSLContext; +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.client.Client; +import javax.ws.rs.client.ClientBuilder; +import javax.ws.rs.core.UriBuilderException; +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.stream.Stream; + /** * Common utilities related to web development. */ @@ -44,9 +45,17 @@ public final class WebUtils { final static ReadWriteLock lock = new ReentrantReadWriteLock(); - private static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath"; - private static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context"; - private static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix"; + public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme"; + public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost"; + public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort"; + + public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto"; + public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host"; + public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port"; + + public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath"; + public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context"; + public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix"; private WebUtils() { } @@ -248,4 +257,118 @@ public final class WebUtils { return false; } + /** + * Returns the value for the first key discovered when inspecting the current request. Will + * return null if there are no keys specified or if none of the specified keys are found. + * + * @param httpServletRequest request + * @param keys http header keys + * @return the value for the first key found, or null if no matching keys found + */ + public static String getFirstHeaderValue(final HttpServletRequest httpServletRequest, final String... keys) { + if (keys == null) { + return null; + } + + for (final String key : keys) { + final String value = httpServletRequest.getHeader(key); + + // if we found an entry for this key, return the value + if (value != null) { + return value; + } + } + + // unable to find any matching keys + return null; + } + + /** + * Determines the scheme based on considering proxy related headers first and then falling back to the scheme of the servlet request. + * + * @param httpServletRequest the request + * @return the determined scheme + */ + public static String determineProxiedScheme(final HttpServletRequest httpServletRequest) { + final String schemeHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_SCHEME_HTTP_HEADER, FORWARDED_PROTO_HTTP_HEADER); + return StringUtils.isBlank(schemeHeaderValue) ? httpServletRequest.getScheme() : schemeHeaderValue; + } + + /** + * Determines the host based on considering proxy related headers first and falling back to the host of the servlet request. + * + * @param httpServletRequest the request + * @return the determined host + */ + public static String determineProxiedHost(final HttpServletRequest httpServletRequest) { + final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER); + final String proxiedHost = determineProxiedHost(hostHeaderValue); + return StringUtils.isBlank(proxiedHost) ? httpServletRequest.getServerName() : proxiedHost; + } + + /** + * Determines the host from the given header. The header value is intended to come from a header like X-ProxyHost or X-Forwarded-Host. + * + * @param hostHeaderValue the header value + * @return the determined host, or null if a host can't be determined + */ + public static String determineProxiedHost(final String hostHeaderValue) { + final String host; + // check for a port in the proxied host header + String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":"); + if (hostSplits.length >= 1 && hostSplits.length <= 2) { + // zero or one occurrence of ':', this is an IPv4 address + // strip off the port by reassigning host the 0th split + host = hostSplits[0]; + } else if (hostSplits.length == 0) { + // hostHeaderValue passed in was null, no splits + host = null; + } else { + // hostHeaderValue has more than one occurrence of ":", IPv6 address + host = hostHeaderValue; + } + return host; + } + + /** + * Determines the port based on first considering proxy related headers and falling back to the port of the servlet request. + * + * @param httpServletRequest the request + * @return the determined port + */ + public static String determineProxiedPort(final HttpServletRequest httpServletRequest) { + final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER); + final String portHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER); + + final String proxiedPort = determineProxiedPort(hostHeaderValue, portHeaderValue); + return StringUtils.isBlank(proxiedPort) ? String.valueOf(httpServletRequest.getServerPort()) : proxiedPort; + } + + /** + * Determines the port based on the header values. The header values are intended to come from headers like X-ProxyHost/X-ProxyPort + * or X-Forwarded-Host/X-Forwarded-Port. + * + * @param hostHeaderValue the host header value + * @param portHeaderValue the host port value + * @return the determined port, or null if one can't be determined + */ + public static String determineProxiedPort(final String hostHeaderValue, final String portHeaderValue) { + final String port; + // check for a port in the proxied host header + String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":"); + // determine the proxied port + final String portFromHostHeader; + if (hostSplits.length == 2) { + // if the port is specified in the proxied host header, it will be overridden by the + // port specified in X-ProxyPort or X-Forwarded-Port + portFromHostHeader = hostSplits[1]; + } else { + portFromHostHeader = null; + } + if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) { + logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header")); + } + port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null); + return port; + } } diff --git a/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.groovy b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsGroovyTest.groovy similarity index 99% rename from nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.groovy rename to nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsGroovyTest.groovy index 5465c2cd91..c8eb7ffd26 100644 --- a/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.groovy +++ b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsGroovyTest.groovy @@ -39,8 +39,8 @@ import javax.ws.rs.core.UriBuilderException import java.security.cert.X509Certificate @RunWith(JUnit4.class) -class WebUtilsTest extends GroovyTestCase { - private static final Logger logger = LoggerFactory.getLogger(WebUtilsTest.class) +class WebUtilsGroovyTest extends GroovyTestCase { + private static final Logger logger = LoggerFactory.getLogger(WebUtilsGroovyTest.class) static final String PCP_HEADER = "X-ProxyContextPath" static final String FC_HEADER = "X-Forwarded-Context" diff --git a/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.java b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.java new file mode 100644 index 0000000000..38690a761d --- /dev/null +++ b/nifi-commons/nifi-web-utils/src/test/groovy/org/apache/nifi/web/util/WebUtilsTest.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.util; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import javax.servlet.http.HttpServletRequest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.junit.Assert.assertEquals; + +@RunWith(MockitoJUnitRunner.class) +public class WebUtilsTest { + + @Mock + private HttpServletRequest request; + + // -- scheme tests + + @Test + public void testDeterminedProxiedSchemeWhenNoHeaders() { + when(request.getHeader(any())).thenReturn(null); + when(request.getScheme()).thenReturn("https"); + assertEquals("https", WebUtils.determineProxiedScheme(request)); + } + + @Test + public void testDeterminedProxiedSchemeWhenXProxySchemeAvailable() { + when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("http"); + assertEquals("http", WebUtils.determineProxiedScheme(request)); + } + + @Test + public void testDeterminedProxiedSchemeWhenXForwardedProtoAvailable() { + when(request.getHeader(eq(WebUtils.FORWARDED_PROTO_HTTP_HEADER))).thenReturn("http"); + assertEquals("http", WebUtils.determineProxiedScheme(request)); + } + + // -- host tests + + @Test + public void testDetermineProxiedHostWhenNoHeaders() { + when(request.getHeader(any())).thenReturn(null); + when(request.getServerName()).thenReturn("localhost"); + assertEquals("localhost", WebUtils.determineProxiedHost(request)); + } + + @Test + public void testDetermineProxiedHostWhenXProxyHostAvailable() { + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host"); + assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request)); + } + + @Test + public void testDetermineProxiedHostWhenXProxyHostAvailableWithPort() { + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:443"); + assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request)); + } + + @Test + public void testDetermineProxiedHostWhenXForwardedHostAvailable() { + when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host"); + assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request)); + } + + @Test + public void testDetermineProxiedHostWhenXForwardedHostAvailableWithPort() { + when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:443"); + assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request)); + } + + // -- port tests + + @Test + public void testDetermineProxiedPortWhenNoHeaders() { + when(request.getServerPort()).thenReturn(443); + assertEquals("443", WebUtils.determineProxiedPort(request)); + } + + @Test + public void testDetermineProxiedPortWhenXProxyPortAvailable() { + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host"); + when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443"); + assertEquals("8443", WebUtils.determineProxiedPort(request)); + } + + @Test + public void testDetermineProxiedPortWhenPortInXProxyHost() { + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234"); + assertEquals("1234", WebUtils.determineProxiedPort(request)); + } + + @Test + public void testDetermineProxiedPortWhenXProxyPortOverridesXProxy() { + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234"); + when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443"); + assertEquals("8443", WebUtils.determineProxiedPort(request)); + } + + @Test + public void testDetermineProxiedPortWhenXForwardedPortAvailable() { + when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host"); + when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443"); + assertEquals("8443", WebUtils.determineProxiedPort(request)); + } + + @Test + public void testDetermineProxiedPortWhenPortInXForwardedHost() { + when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234"); + assertEquals("1234", WebUtils.determineProxiedPort(request)); + } + + @Test + public void testDetermineProxiedPortWhenXForwardedPortOverridesXForwardedHost() { + when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234"); + when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443"); + assertEquals("8443", WebUtils.determineProxiedPort(request)); + } + +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java index 196ba273a1..b24b9a38df 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/main/java/org/apache/nifi/web/api/ApplicationResource.java @@ -16,40 +16,8 @@ */ package org.apache.nifi.web.api; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; -import static org.apache.commons.lang3.StringUtils.isEmpty; -import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME; -import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE; - import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.Enumeration; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.TreeMap; -import java.util.UUID; -import java.util.concurrent.TimeUnit; -import java.util.function.BiFunction; -import java.util.function.Consumer; -import java.util.function.Function; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.ws.rs.core.CacheControl; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.MultivaluedHashMap; -import javax.ws.rs.core.MultivaluedMap; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.Response.ResponseBuilder; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriBuilderException; -import javax.ws.rs.core.UriInfo; import org.apache.commons.lang3.StringUtils; import org.apache.nifi.authorization.AuthorizableLookup; import org.apache.nifi.authorization.AuthorizeAccess; @@ -92,6 +60,45 @@ import org.apache.nifi.web.util.WebUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.core.CacheControl; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.MultivaluedHashMap; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.Response.ResponseBuilder; +import javax.ws.rs.core.UriBuilder; +import javax.ws.rs.core.UriBuilderException; +import javax.ws.rs.core.UriInfo; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +import static javax.ws.rs.core.Response.Status.NOT_FOUND; +import static org.apache.commons.lang3.StringUtils.isEmpty; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE; +import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.FORWARDED_PROTO_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.FORWARDED_HOST_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.FORWARDED_PORT_HTTP_HEADER; + /** * Base class for controllers. */ @@ -101,19 +108,6 @@ public abstract class ApplicationResource { public static final String CLIENT_ID = "clientId"; public static final String DISCONNECTED_NODE_ACKNOWLEDGED = "disconnectedNodeAcknowledged"; - public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme"; - public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost"; - public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort"; - public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath"; - - public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto"; - public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host"; - public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port"; - public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context"; - - // Traefik-specific headers - public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix"; - protected static final String NON_GUARANTEED_ENDPOINT = "Note: This endpoint is subject to change as NiFi and it's REST API evolve."; private static final Logger logger = LoggerFactory.getLogger(ApplicationResource.class); @@ -157,8 +151,8 @@ public abstract class ApplicationResource { final String hostHeaderValue = getFirstHeaderValue(PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER); final String portHeaderValue = getFirstHeaderValue(PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER); - final String host = determineProxiedHost(hostHeaderValue); - final String port = determineProxiedPort(hostHeaderValue, portHeaderValue); + final String host = WebUtils.determineProxiedHost(hostHeaderValue); + final String port = WebUtils.determineProxiedPort(hostHeaderValue, portHeaderValue); // Catch header poisoning String allowedContextPaths = properties.getAllowedContextPaths(); @@ -194,44 +188,6 @@ public abstract class ApplicationResource { return uri; } - private String determineProxiedHost(String hostHeaderValue) { - final String host; - // check for a port in the proxied host header - String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":"); - if (hostSplits.length >= 1 && hostSplits.length <= 2) { - // zero or one occurrence of ':', this is an IPv4 address - // strip off the port by reassigning host the 0th split - host = hostSplits[0]; - } else if (hostSplits.length == 0) { - // hostHeaderValue passed in was null, no splits - host = null; - } else { - // hostHeaderValue has more than one occurrence of ":", IPv6 address - host = hostHeaderValue; - } - return host; - } - - private String determineProxiedPort(String hostHeaderValue, String portHeaderValue) { - final String port; - // check for a port in the proxied host header - String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":"); - // determine the proxied port - final String portFromHostHeader; - if (hostSplits.length == 2) { - // if the port is specified in the proxied host header, it will be overridden by the - // port specified in X-ProxyPort or X-Forwarded-Port - portFromHostHeader = hostSplits[1]; - } else { - portFromHostHeader = null; - } - if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) { - logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header")); - } - port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null); - return port; - } - /** * Edit the response headers to indicating no caching. * @@ -403,21 +359,7 @@ public abstract class ApplicationResource { * @return the value for the first key found */ private String getFirstHeaderValue(final String... keys) { - if (keys == null) { - return null; - } - - for (final String key : keys) { - final String value = httpServletRequest.getHeader(key); - - // if we found an entry for this key, return the value - if (value != null) { - return value; - } - } - - // unable to find any matching keys - return null; + return WebUtils.getFirstHeaderValue(httpServletRequest, keys); } /** diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java index cbe3030076..84573b0bd9 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-api/src/test/java/org/apache/nifi/web/api/TestDataTransferResource.java @@ -16,31 +16,6 @@ */ package org.apache.nifi.web.api; -import static org.apache.nifi.web.api.ApplicationResource.PROXY_HOST_HTTP_HEADER; -import static org.apache.nifi.web.api.ApplicationResource.PROXY_PORT_HTTP_HEADER; -import static org.apache.nifi.web.api.ApplicationResource.PROXY_SCHEME_HTTP_HEADER; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.io.InputStream; -import java.lang.reflect.Field; -import java.net.URI; -import java.net.URISyntaxException; -import java.net.URL; -import javax.servlet.ServletContext; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.StreamingOutput; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriInfo; import org.apache.nifi.authorization.AuthorizableLookup; import org.apache.nifi.authorization.resource.ResourceType; import org.apache.nifi.remote.HttpRemoteSiteListener; @@ -58,6 +33,32 @@ import org.apache.nifi.web.api.entity.TransactionResultEntity; import org.junit.BeforeClass; import org.junit.Test; +import javax.servlet.ServletContext; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.StreamingOutput; +import javax.ws.rs.core.UriBuilder; +import javax.ws.rs.core.UriInfo; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; + +import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER; +import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class TestDataTransferResource { @BeforeClass diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java index b85f176022..6ac659ae04 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/NiFiSAMLContextProviderImpl.java @@ -17,17 +17,19 @@ package org.apache.nifi.web.security.saml.impl; import org.apache.nifi.web.security.saml.NiFiSAMLContextProvider; +import org.apache.nifi.web.security.saml.impl.http.HttpServletRequestWithParameters; +import org.apache.nifi.web.security.saml.impl.http.ProxyAwareHttpServletRequestWrapper; import org.opensaml.saml2.metadata.provider.MetadataProviderException; import org.opensaml.ws.transport.http.HttpServletRequestAdapter; import org.opensaml.ws.transport.http.HttpServletResponseAdapter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.security.saml.context.SAMLContextProviderImpl; import org.springframework.security.saml.context.SAMLMessageContext; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Map; /** @@ -35,6 +37,8 @@ import java.util.Map; */ public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl implements NiFiSAMLContextProvider { + private static final Logger LOGGER = LoggerFactory.getLogger(NiFiSAMLContextProviderImpl.class); + @Override public SAMLMessageContext getLocalEntity(HttpServletRequest request, HttpServletResponse response, Map parameters) throws MetadataProviderException { @@ -60,55 +64,20 @@ public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl impleme } protected void populateGenericContext(HttpServletRequest request, HttpServletResponse response, Map parameters, SAMLMessageContext context) { - HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(request, parameters); - HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, request.isSecure()); + HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request); + LOGGER.debug("Populating SAMLContext - request wrapper URL is [{}]", requestWrapper.getRequestURL().toString()); + + HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(requestWrapper, parameters); + HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, requestWrapper.isSecure()); // Store attribute which cannot be located from InTransport directly - request.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, request.getContextPath()); + requestWrapper.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, requestWrapper.getContextPath()); context.setMetadataProvider(metadata); context.setInboundMessageTransport(inTransport); context.setOutboundMessageTransport(outTransport); - context.setMessageStorage(storageFactory.getMessageStorage(request)); + context.setMessageStorage(storageFactory.getMessageStorage(requestWrapper)); } - /** - * Extends the HttpServletRequestAdapter with a provided set of parameters. - */ - private static class HttpServletRequestWithParameters extends HttpServletRequestAdapter { - - private final Map providedParameters; - - public HttpServletRequestWithParameters(HttpServletRequest request, Map providedParameters) { - super(request); - this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters; - } - - @Override - public String getParameterValue(String name) { - String value = super.getParameterValue(name); - if (value == null) { - value = providedParameters.get(name); - } - return value; - } - - @Override - public List getParameterValues(String name) { - List combinedValues = new ArrayList<>(); - - List initialValues = super.getParameterValues(name); - if (initialValues != null) { - combinedValues.addAll(initialValues); - } - - String providedValue = providedParameters.get(name); - if (providedValue != null) { - combinedValues.add(providedValue); - } - - return combinedValues; - } - } } diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/HttpServletRequestWithParameters.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/HttpServletRequestWithParameters.java new file mode 100644 index 0000000000..0c915973c2 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/HttpServletRequestWithParameters.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.saml.impl.http; + +import org.opensaml.ws.transport.http.HttpServletRequestAdapter; + +import javax.servlet.http.HttpServletRequest; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Extends the HttpServletRequestAdapter with a provided set of parameters. + */ +public class HttpServletRequestWithParameters extends HttpServletRequestAdapter { + + private final Map providedParameters; + + public HttpServletRequestWithParameters(final HttpServletRequest request, final Map providedParameters) { + super(request); + this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters; + } + + @Override + public String getParameterValue(final String name) { + String value = super.getParameterValue(name); + if (value == null) { + value = providedParameters.get(name); + } + return value; + } + + @Override + public List getParameterValues(final String name) { + List combinedValues = new ArrayList<>(); + + List initialValues = super.getParameterValues(name); + if (initialValues != null) { + combinedValues.addAll(initialValues); + } + + String providedValue = providedParameters.get(name); + if (providedValue != null) { + combinedValues.add(providedValue); + } + + return combinedValues; + } +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/ProxyAwareHttpServletRequestWrapper.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/ProxyAwareHttpServletRequestWrapper.java new file mode 100644 index 0000000000..09121ec991 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/main/java/org/apache/nifi/web/security/saml/impl/http/ProxyAwareHttpServletRequestWrapper.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.saml.impl.http; + +import org.apache.nifi.web.util.WebUtils; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + +/** + * Extension of HttpServletRequestWrapper that respects proxied/forwarded header values for scheme, host, port, and context path. + *

+ * If NiFi generates a SAML request using proxied values so that the IDP redirects back through the proxy, then this is needed + * so that when Open SAML checks the Destination in the SAML response, it will match with the values here. + *

+ * This class is based on SAMLContextProviderLB from spring-security-saml. + */ +public class ProxyAwareHttpServletRequestWrapper extends HttpServletRequestWrapper { + + private final String scheme; + private final String serverName; + private final int serverPort; + private final String proxyContextPath; + private final String contextPath; + + public ProxyAwareHttpServletRequestWrapper(final HttpServletRequest request) { + super(request); + this.scheme = WebUtils.determineProxiedScheme(request); + this.serverName = WebUtils.determineProxiedHost(request); + this.serverPort = Integer.valueOf(WebUtils.determineProxiedPort(request)); + + final String tempProxyContextPath = WebUtils.normalizeContextPath(WebUtils.determineContextPath(request)); + this.proxyContextPath = tempProxyContextPath.equals("/") ? "" : tempProxyContextPath; + + this.contextPath = request.getContextPath(); + } + + @Override + public String getContextPath() { + return contextPath; + } + + @Override + public String getScheme() { + return scheme; + } + + @Override + public String getServerName() { + return serverName; + } + + @Override + public int getServerPort() { + return serverPort; + } + + @Override + public String getRequestURI() { + StringBuilder sb = new StringBuilder(contextPath); + sb.append(getServletPath()); + return sb.toString(); + } + + @Override + public StringBuffer getRequestURL() { + StringBuffer sb = new StringBuffer(); + sb.append(scheme).append("://").append(serverName); + sb.append(":").append(serverPort); + sb.append(proxyContextPath); + sb.append(contextPath); + sb.append(getServletPath()); + if (getPathInfo() != null) sb.append(getPathInfo()); + return sb; + } + + @Override + public boolean isSecure() { + return "https".equalsIgnoreCase(scheme); + } + +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestHttpServletRequestWithParameters.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestHttpServletRequestWithParameters.java new file mode 100644 index 0000000000..694dea20d3 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestHttpServletRequestWithParameters.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.saml.impl.http; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensaml.ws.transport.http.HttpServletRequestAdapter; + +import javax.servlet.http.HttpServletRequest; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.junit.Assert.assertEquals; + +@RunWith(MockitoJUnitRunner.class) +public class TestHttpServletRequestWithParameters { + + @Mock + private HttpServletRequest request; + + @Test + public void testGetParameterValueWhenNoExtraParameters() { + final String paramName = "fooParam"; + final String paramValue = "fooValue"; + when(request.getParameter(eq(paramName))).thenReturn(paramValue); + + final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap()); + final String result = requestAdapter.getParameterValue(paramName); + assertEquals(paramValue, result); + } + + @Test + public void testGetParameterValueWhenExtraParameters() { + final String paramName = "fooParam"; + final String paramValue = "fooValue"; + + final Map extraParams = new HashMap<>(); + extraParams.put(paramName, paramValue); + + when(request.getParameter(any())).thenReturn(null); + + final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams); + final String result = requestAdapter.getParameterValue(paramName); + assertEquals(paramValue, result); + } + + @Test + public void testGetParameterValuesWhenNoExtraParameters() { + final String paramName = "fooParam"; + final String paramValue = "fooValue"; + when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue}); + + final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap()); + final List results = requestAdapter.getParameterValues(paramName); + assertEquals(1, results.size()); + assertEquals(paramValue, results.get(0)); + } + + @Test + public void testGetParameterValuesWhenExtraParameters() { + final String paramName = "fooParam"; + final String paramValue1 = "fooValue1"; + when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue1}); + + final String paramValue2 = "fooValue2"; + final Map extraParams = new HashMap<>(); + extraParams.put(paramName, paramValue2); + + final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams); + final List results = requestAdapter.getParameterValues(paramName); + assertEquals(2, results.size()); + assertTrue(results.contains(paramValue1)); + assertTrue(results.contains(paramValue2)); + } +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestProxyAwareHttpServletRequestWrapper.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestProxyAwareHttpServletRequestWrapper.java new file mode 100644 index 0000000000..e303c7e770 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/saml/impl/http/TestProxyAwareHttpServletRequestWrapper.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.saml.impl.http; + +import org.apache.nifi.web.util.WebUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class TestProxyAwareHttpServletRequestWrapper { + + @Mock + private HttpServletRequest request; + + @Test + public void testWhenNotProxied() { + when(request.getScheme()).thenReturn("https"); + when(request.getServerName()).thenReturn("localhost"); + when(request.getServerPort()).thenReturn(8443); + when(request.getContextPath()).thenReturn("/nifi-api"); + when(request.getServletPath()).thenReturn("/access/saml/metadata"); + when(request.getHeader(any())).thenReturn(null); + + final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request); + assertEquals("https://localhost:8443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString()); + } + + @Test + public void testWhenProxied() { + when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https"); + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host"); + when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443"); + when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/proxy-context"); + when(request.getContextPath()).thenReturn("/nifi-api"); + when(request.getServletPath()).thenReturn("/access/saml/metadata"); + + final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request); + assertEquals("https://proxy-host:443/proxy-context/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString()); + } + + @Test + public void testWhenProxiedWithEmptyProxyContextPath() { + when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https"); + when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host"); + when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443"); + when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/"); + when(request.getContextPath()).thenReturn("/nifi-api"); + when(request.getServletPath()).thenReturn("/access/saml/metadata"); + + final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request); + assertEquals("https://proxy-host:443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString()); + } +}