NIFI-8218 This closes #4816. Use proxy headers when available when getting request values while processing SAML responses

Signed-off-by: Joe Witt <joewitt@apache.org>
This commit is contained in:
Bryan Bende 2021-02-10 09:39:27 -05:00 committed by Joe Witt
parent d5d520764d
commit 1d82fb8e01
No known key found for this signature in database
GPG Key ID: 9093BF854F811A1A
10 changed files with 696 additions and 187 deletions

View File

@ -16,18 +16,6 @@
*/ */
package org.apache.nifi.web.util; package org.apache.nifi.web.util;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Stream;
import javax.net.ssl.SSLContext;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.core.UriBuilderException;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.http.conn.ssl.DefaultHostnameVerifier; import org.apache.http.conn.ssl.DefaultHostnameVerifier;
import org.glassfish.jersey.client.ClientConfig; import org.glassfish.jersey.client.ClientConfig;
@ -35,6 +23,19 @@ import org.glassfish.jersey.jackson.internal.jackson.jaxrs.json.JacksonJaxbJsonP
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.core.UriBuilderException;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Stream;
/** /**
* Common utilities related to web development. * Common utilities related to web development.
*/ */
@ -44,9 +45,17 @@ public final class WebUtils {
final static ReadWriteLock lock = new ReentrantReadWriteLock(); final static ReadWriteLock lock = new ReentrantReadWriteLock();
private static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath"; public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme";
private static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context"; public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost";
private static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix"; public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort";
public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto";
public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host";
public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port";
public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
private WebUtils() { private WebUtils() {
} }
@ -248,4 +257,118 @@ public final class WebUtils {
return false; return false;
} }
/**
* Returns the value for the first key discovered when inspecting the current request. Will
* return null if there are no keys specified or if none of the specified keys are found.
*
* @param httpServletRequest request
* @param keys http header keys
* @return the value for the first key found, or null if no matching keys found
*/
public static String getFirstHeaderValue(final HttpServletRequest httpServletRequest, final String... keys) {
if (keys == null) {
return null;
}
for (final String key : keys) {
final String value = httpServletRequest.getHeader(key);
// if we found an entry for this key, return the value
if (value != null) {
return value;
}
}
// unable to find any matching keys
return null;
}
/**
* Determines the scheme based on considering proxy related headers first and then falling back to the scheme of the servlet request.
*
* @param httpServletRequest the request
* @return the determined scheme
*/
public static String determineProxiedScheme(final HttpServletRequest httpServletRequest) {
final String schemeHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_SCHEME_HTTP_HEADER, FORWARDED_PROTO_HTTP_HEADER);
return StringUtils.isBlank(schemeHeaderValue) ? httpServletRequest.getScheme() : schemeHeaderValue;
}
/**
* Determines the host based on considering proxy related headers first and falling back to the host of the servlet request.
*
* @param httpServletRequest the request
* @return the determined host
*/
public static String determineProxiedHost(final HttpServletRequest httpServletRequest) {
final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
final String proxiedHost = determineProxiedHost(hostHeaderValue);
return StringUtils.isBlank(proxiedHost) ? httpServletRequest.getServerName() : proxiedHost;
}
/**
* Determines the host from the given header. The header value is intended to come from a header like X-ProxyHost or X-Forwarded-Host.
*
* @param hostHeaderValue the header value
* @return the determined host, or null if a host can't be determined
*/
public static String determineProxiedHost(final String hostHeaderValue) {
final String host;
// check for a port in the proxied host header
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
if (hostSplits.length >= 1 && hostSplits.length <= 2) {
// zero or one occurrence of ':', this is an IPv4 address
// strip off the port by reassigning host the 0th split
host = hostSplits[0];
} else if (hostSplits.length == 0) {
// hostHeaderValue passed in was null, no splits
host = null;
} else {
// hostHeaderValue has more than one occurrence of ":", IPv6 address
host = hostHeaderValue;
}
return host;
}
/**
* Determines the port based on first considering proxy related headers and falling back to the port of the servlet request.
*
* @param httpServletRequest the request
* @return the determined port
*/
public static String determineProxiedPort(final HttpServletRequest httpServletRequest) {
final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
final String portHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER);
final String proxiedPort = determineProxiedPort(hostHeaderValue, portHeaderValue);
return StringUtils.isBlank(proxiedPort) ? String.valueOf(httpServletRequest.getServerPort()) : proxiedPort;
}
/**
* Determines the port based on the header values. The header values are intended to come from headers like X-ProxyHost/X-ProxyPort
* or X-Forwarded-Host/X-Forwarded-Port.
*
* @param hostHeaderValue the host header value
* @param portHeaderValue the host port value
* @return the determined port, or null if one can't be determined
*/
public static String determineProxiedPort(final String hostHeaderValue, final String portHeaderValue) {
final String port;
// check for a port in the proxied host header
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
// determine the proxied port
final String portFromHostHeader;
if (hostSplits.length == 2) {
// if the port is specified in the proxied host header, it will be overridden by the
// port specified in X-ProxyPort or X-Forwarded-Port
portFromHostHeader = hostSplits[1];
} else {
portFromHostHeader = null;
}
if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) {
logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header"));
}
port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null);
return port;
}
} }

View File

@ -39,8 +39,8 @@ import javax.ws.rs.core.UriBuilderException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
class WebUtilsTest extends GroovyTestCase { class WebUtilsGroovyTest extends GroovyTestCase {
private static final Logger logger = LoggerFactory.getLogger(WebUtilsTest.class) private static final Logger logger = LoggerFactory.getLogger(WebUtilsGroovyTest.class)
static final String PCP_HEADER = "X-ProxyContextPath" static final String PCP_HEADER = "X-ProxyContextPath"
static final String FC_HEADER = "X-Forwarded-Context" static final String FC_HEADER = "X-Forwarded-Context"

View File

@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import javax.servlet.http.HttpServletRequest;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertEquals;
@RunWith(MockitoJUnitRunner.class)
public class WebUtilsTest {
@Mock
private HttpServletRequest request;
// -- scheme tests
@Test
public void testDeterminedProxiedSchemeWhenNoHeaders() {
when(request.getHeader(any())).thenReturn(null);
when(request.getScheme()).thenReturn("https");
assertEquals("https", WebUtils.determineProxiedScheme(request));
}
@Test
public void testDeterminedProxiedSchemeWhenXProxySchemeAvailable() {
when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("http");
assertEquals("http", WebUtils.determineProxiedScheme(request));
}
@Test
public void testDeterminedProxiedSchemeWhenXForwardedProtoAvailable() {
when(request.getHeader(eq(WebUtils.FORWARDED_PROTO_HTTP_HEADER))).thenReturn("http");
assertEquals("http", WebUtils.determineProxiedScheme(request));
}
// -- host tests
@Test
public void testDetermineProxiedHostWhenNoHeaders() {
when(request.getHeader(any())).thenReturn(null);
when(request.getServerName()).thenReturn("localhost");
assertEquals("localhost", WebUtils.determineProxiedHost(request));
}
@Test
public void testDetermineProxiedHostWhenXProxyHostAvailable() {
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host");
assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request));
}
@Test
public void testDetermineProxiedHostWhenXProxyHostAvailableWithPort() {
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:443");
assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request));
}
@Test
public void testDetermineProxiedHostWhenXForwardedHostAvailable() {
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host");
assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request));
}
@Test
public void testDetermineProxiedHostWhenXForwardedHostAvailableWithPort() {
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:443");
assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request));
}
// -- port tests
@Test
public void testDetermineProxiedPortWhenNoHeaders() {
when(request.getServerPort()).thenReturn(443);
assertEquals("443", WebUtils.determineProxiedPort(request));
}
@Test
public void testDetermineProxiedPortWhenXProxyPortAvailable() {
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host");
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443");
assertEquals("8443", WebUtils.determineProxiedPort(request));
}
@Test
public void testDetermineProxiedPortWhenPortInXProxyHost() {
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234");
assertEquals("1234", WebUtils.determineProxiedPort(request));
}
@Test
public void testDetermineProxiedPortWhenXProxyPortOverridesXProxy() {
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234");
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443");
assertEquals("8443", WebUtils.determineProxiedPort(request));
}
@Test
public void testDetermineProxiedPortWhenXForwardedPortAvailable() {
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host");
when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443");
assertEquals("8443", WebUtils.determineProxiedPort(request));
}
@Test
public void testDetermineProxiedPortWhenPortInXForwardedHost() {
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234");
assertEquals("1234", WebUtils.determineProxiedPort(request));
}
@Test
public void testDetermineProxiedPortWhenXForwardedPortOverridesXForwardedHost() {
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234");
when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443");
assertEquals("8443", WebUtils.determineProxiedPort(request));
}
}

View File

@ -16,40 +16,8 @@
*/ */
package org.apache.nifi.web.api; package org.apache.nifi.web.api;
import static javax.ws.rs.core.Response.Status.NOT_FOUND;
import static org.apache.commons.lang3.StringUtils.isEmpty;
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME;
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE;
import com.google.common.cache.Cache; import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheBuilder;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.CacheControl;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriBuilderException;
import javax.ws.rs.core.UriInfo;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.authorization.AuthorizableLookup; import org.apache.nifi.authorization.AuthorizableLookup;
import org.apache.nifi.authorization.AuthorizeAccess; import org.apache.nifi.authorization.AuthorizeAccess;
@ -92,6 +60,45 @@ import org.apache.nifi.web.util.WebUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.CacheControl;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriBuilderException;
import javax.ws.rs.core.UriInfo;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import static javax.ws.rs.core.Response.Status.NOT_FOUND;
import static org.apache.commons.lang3.StringUtils.isEmpty;
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME;
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE;
import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.FORWARDED_PROTO_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.FORWARDED_HOST_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.FORWARDED_PORT_HTTP_HEADER;
/** /**
* Base class for controllers. * Base class for controllers.
*/ */
@ -101,19 +108,6 @@ public abstract class ApplicationResource {
public static final String CLIENT_ID = "clientId"; public static final String CLIENT_ID = "clientId";
public static final String DISCONNECTED_NODE_ACKNOWLEDGED = "disconnectedNodeAcknowledged"; public static final String DISCONNECTED_NODE_ACKNOWLEDGED = "disconnectedNodeAcknowledged";
public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme";
public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost";
public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort";
public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto";
public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host";
public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port";
public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
// Traefik-specific headers
public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
protected static final String NON_GUARANTEED_ENDPOINT = "Note: This endpoint is subject to change as NiFi and it's REST API evolve."; protected static final String NON_GUARANTEED_ENDPOINT = "Note: This endpoint is subject to change as NiFi and it's REST API evolve.";
private static final Logger logger = LoggerFactory.getLogger(ApplicationResource.class); private static final Logger logger = LoggerFactory.getLogger(ApplicationResource.class);
@ -157,8 +151,8 @@ public abstract class ApplicationResource {
final String hostHeaderValue = getFirstHeaderValue(PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER); final String hostHeaderValue = getFirstHeaderValue(PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
final String portHeaderValue = getFirstHeaderValue(PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER); final String portHeaderValue = getFirstHeaderValue(PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER);
final String host = determineProxiedHost(hostHeaderValue); final String host = WebUtils.determineProxiedHost(hostHeaderValue);
final String port = determineProxiedPort(hostHeaderValue, portHeaderValue); final String port = WebUtils.determineProxiedPort(hostHeaderValue, portHeaderValue);
// Catch header poisoning // Catch header poisoning
String allowedContextPaths = properties.getAllowedContextPaths(); String allowedContextPaths = properties.getAllowedContextPaths();
@ -194,44 +188,6 @@ public abstract class ApplicationResource {
return uri; return uri;
} }
private String determineProxiedHost(String hostHeaderValue) {
final String host;
// check for a port in the proxied host header
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
if (hostSplits.length >= 1 && hostSplits.length <= 2) {
// zero or one occurrence of ':', this is an IPv4 address
// strip off the port by reassigning host the 0th split
host = hostSplits[0];
} else if (hostSplits.length == 0) {
// hostHeaderValue passed in was null, no splits
host = null;
} else {
// hostHeaderValue has more than one occurrence of ":", IPv6 address
host = hostHeaderValue;
}
return host;
}
private String determineProxiedPort(String hostHeaderValue, String portHeaderValue) {
final String port;
// check for a port in the proxied host header
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
// determine the proxied port
final String portFromHostHeader;
if (hostSplits.length == 2) {
// if the port is specified in the proxied host header, it will be overridden by the
// port specified in X-ProxyPort or X-Forwarded-Port
portFromHostHeader = hostSplits[1];
} else {
portFromHostHeader = null;
}
if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) {
logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header"));
}
port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null);
return port;
}
/** /**
* Edit the response headers to indicating no caching. * Edit the response headers to indicating no caching.
* *
@ -403,21 +359,7 @@ public abstract class ApplicationResource {
* @return the value for the first key found * @return the value for the first key found
*/ */
private String getFirstHeaderValue(final String... keys) { private String getFirstHeaderValue(final String... keys) {
if (keys == null) { return WebUtils.getFirstHeaderValue(httpServletRequest, keys);
return null;
}
for (final String key : keys) {
final String value = httpServletRequest.getHeader(key);
// if we found an entry for this key, return the value
if (value != null) {
return value;
}
}
// unable to find any matching keys
return null;
} }
/** /**

View File

@ -16,31 +16,6 @@
*/ */
package org.apache.nifi.web.api; package org.apache.nifi.web.api;
import static org.apache.nifi.web.api.ApplicationResource.PROXY_HOST_HTTP_HEADER;
import static org.apache.nifi.web.api.ApplicationResource.PROXY_PORT_HTTP_HEADER;
import static org.apache.nifi.web.api.ApplicationResource.PROXY_SCHEME_HTTP_HEADER;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;
import org.apache.nifi.authorization.AuthorizableLookup; import org.apache.nifi.authorization.AuthorizableLookup;
import org.apache.nifi.authorization.resource.ResourceType; import org.apache.nifi.authorization.resource.ResourceType;
import org.apache.nifi.remote.HttpRemoteSiteListener; import org.apache.nifi.remote.HttpRemoteSiteListener;
@ -58,6 +33,32 @@ import org.apache.nifi.web.api.entity.TransactionResultEntity;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER;
import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TestDataTransferResource { public class TestDataTransferResource {
@BeforeClass @BeforeClass

View File

@ -17,17 +17,19 @@
package org.apache.nifi.web.security.saml.impl; package org.apache.nifi.web.security.saml.impl;
import org.apache.nifi.web.security.saml.NiFiSAMLContextProvider; import org.apache.nifi.web.security.saml.NiFiSAMLContextProvider;
import org.apache.nifi.web.security.saml.impl.http.HttpServletRequestWithParameters;
import org.apache.nifi.web.security.saml.impl.http.ProxyAwareHttpServletRequestWrapper;
import org.opensaml.saml2.metadata.provider.MetadataProviderException; import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter; import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import org.opensaml.ws.transport.http.HttpServletResponseAdapter; import org.opensaml.ws.transport.http.HttpServletResponseAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.saml.context.SAMLContextProviderImpl; import org.springframework.security.saml.context.SAMLContextProviderImpl;
import org.springframework.security.saml.context.SAMLMessageContext; import org.springframework.security.saml.context.SAMLMessageContext;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
@ -35,6 +37,8 @@ import java.util.Map;
*/ */
public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl implements NiFiSAMLContextProvider { public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl implements NiFiSAMLContextProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(NiFiSAMLContextProviderImpl.class);
@Override @Override
public SAMLMessageContext getLocalEntity(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters) public SAMLMessageContext getLocalEntity(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters)
throws MetadataProviderException { throws MetadataProviderException {
@ -60,55 +64,20 @@ public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl impleme
} }
protected void populateGenericContext(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters, SAMLMessageContext context) { protected void populateGenericContext(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters, SAMLMessageContext context) {
HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(request, parameters); HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, request.isSecure()); LOGGER.debug("Populating SAMLContext - request wrapper URL is [{}]", requestWrapper.getRequestURL().toString());
HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(requestWrapper, parameters);
HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, requestWrapper.isSecure());
// Store attribute which cannot be located from InTransport directly // Store attribute which cannot be located from InTransport directly
request.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, request.getContextPath()); requestWrapper.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, requestWrapper.getContextPath());
context.setMetadataProvider(metadata); context.setMetadataProvider(metadata);
context.setInboundMessageTransport(inTransport); context.setInboundMessageTransport(inTransport);
context.setOutboundMessageTransport(outTransport); context.setOutboundMessageTransport(outTransport);
context.setMessageStorage(storageFactory.getMessageStorage(request)); context.setMessageStorage(storageFactory.getMessageStorage(requestWrapper));
} }
/**
* Extends the HttpServletRequestAdapter with a provided set of parameters.
*/
private static class HttpServletRequestWithParameters extends HttpServletRequestAdapter {
private final Map<String, String> providedParameters;
public HttpServletRequestWithParameters(HttpServletRequest request, Map<String,String> providedParameters) {
super(request);
this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters;
}
@Override
public String getParameterValue(String name) {
String value = super.getParameterValue(name);
if (value == null) {
value = providedParameters.get(name);
}
return value;
}
@Override
public List<String> getParameterValues(String name) {
List<String> combinedValues = new ArrayList<>();
List<String> initialValues = super.getParameterValues(name);
if (initialValues != null) {
combinedValues.addAll(initialValues);
}
String providedValue = providedParameters.get(name);
if (providedValue != null) {
combinedValues.add(providedValue);
}
return combinedValues;
}
}
} }

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.saml.impl.http;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* Extends the HttpServletRequestAdapter with a provided set of parameters.
*/
public class HttpServletRequestWithParameters extends HttpServletRequestAdapter {
private final Map<String, String> providedParameters;
public HttpServletRequestWithParameters(final HttpServletRequest request, final Map<String, String> providedParameters) {
super(request);
this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters;
}
@Override
public String getParameterValue(final String name) {
String value = super.getParameterValue(name);
if (value == null) {
value = providedParameters.get(name);
}
return value;
}
@Override
public List<String> getParameterValues(final String name) {
List<String> combinedValues = new ArrayList<>();
List<String> initialValues = super.getParameterValues(name);
if (initialValues != null) {
combinedValues.addAll(initialValues);
}
String providedValue = providedParameters.get(name);
if (providedValue != null) {
combinedValues.add(providedValue);
}
return combinedValues;
}
}

View File

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.saml.impl.http;
import org.apache.nifi.web.util.WebUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
/**
* Extension of HttpServletRequestWrapper that respects proxied/forwarded header values for scheme, host, port, and context path.
* <p>
* If NiFi generates a SAML request using proxied values so that the IDP redirects back through the proxy, then this is needed
* so that when Open SAML checks the Destination in the SAML response, it will match with the values here.
* <p>
* This class is based on SAMLContextProviderLB from spring-security-saml.
*/
public class ProxyAwareHttpServletRequestWrapper extends HttpServletRequestWrapper {
private final String scheme;
private final String serverName;
private final int serverPort;
private final String proxyContextPath;
private final String contextPath;
public ProxyAwareHttpServletRequestWrapper(final HttpServletRequest request) {
super(request);
this.scheme = WebUtils.determineProxiedScheme(request);
this.serverName = WebUtils.determineProxiedHost(request);
this.serverPort = Integer.valueOf(WebUtils.determineProxiedPort(request));
final String tempProxyContextPath = WebUtils.normalizeContextPath(WebUtils.determineContextPath(request));
this.proxyContextPath = tempProxyContextPath.equals("/") ? "" : tempProxyContextPath;
this.contextPath = request.getContextPath();
}
@Override
public String getContextPath() {
return contextPath;
}
@Override
public String getScheme() {
return scheme;
}
@Override
public String getServerName() {
return serverName;
}
@Override
public int getServerPort() {
return serverPort;
}
@Override
public String getRequestURI() {
StringBuilder sb = new StringBuilder(contextPath);
sb.append(getServletPath());
return sb.toString();
}
@Override
public StringBuffer getRequestURL() {
StringBuffer sb = new StringBuffer();
sb.append(scheme).append("://").append(serverName);
sb.append(":").append(serverPort);
sb.append(proxyContextPath);
sb.append(contextPath);
sb.append(getServletPath());
if (getPathInfo() != null) sb.append(getPathInfo());
return sb;
}
@Override
public boolean isSecure() {
return "https".equalsIgnoreCase(scheme);
}
}

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.saml.impl.http;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertEquals;
@RunWith(MockitoJUnitRunner.class)
public class TestHttpServletRequestWithParameters {
@Mock
private HttpServletRequest request;
@Test
public void testGetParameterValueWhenNoExtraParameters() {
final String paramName = "fooParam";
final String paramValue = "fooValue";
when(request.getParameter(eq(paramName))).thenReturn(paramValue);
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap());
final String result = requestAdapter.getParameterValue(paramName);
assertEquals(paramValue, result);
}
@Test
public void testGetParameterValueWhenExtraParameters() {
final String paramName = "fooParam";
final String paramValue = "fooValue";
final Map<String,String> extraParams = new HashMap<>();
extraParams.put(paramName, paramValue);
when(request.getParameter(any())).thenReturn(null);
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
final String result = requestAdapter.getParameterValue(paramName);
assertEquals(paramValue, result);
}
@Test
public void testGetParameterValuesWhenNoExtraParameters() {
final String paramName = "fooParam";
final String paramValue = "fooValue";
when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue});
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap());
final List<String> results = requestAdapter.getParameterValues(paramName);
assertEquals(1, results.size());
assertEquals(paramValue, results.get(0));
}
@Test
public void testGetParameterValuesWhenExtraParameters() {
final String paramName = "fooParam";
final String paramValue1 = "fooValue1";
when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue1});
final String paramValue2 = "fooValue2";
final Map<String,String> extraParams = new HashMap<>();
extraParams.put(paramName, paramValue2);
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
final List<String> results = requestAdapter.getParameterValues(paramName);
assertEquals(2, results.size());
assertTrue(results.contains(paramValue1));
assertTrue(results.contains(paramValue2));
}
}

View File

@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.saml.impl.http;
import org.apache.nifi.web.util.WebUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class TestProxyAwareHttpServletRequestWrapper {
@Mock
private HttpServletRequest request;
@Test
public void testWhenNotProxied() {
when(request.getScheme()).thenReturn("https");
when(request.getServerName()).thenReturn("localhost");
when(request.getServerPort()).thenReturn(8443);
when(request.getContextPath()).thenReturn("/nifi-api");
when(request.getServletPath()).thenReturn("/access/saml/metadata");
when(request.getHeader(any())).thenReturn(null);
final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
assertEquals("https://localhost:8443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
}
@Test
public void testWhenProxied() {
when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https");
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host");
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443");
when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/proxy-context");
when(request.getContextPath()).thenReturn("/nifi-api");
when(request.getServletPath()).thenReturn("/access/saml/metadata");
final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
assertEquals("https://proxy-host:443/proxy-context/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
}
@Test
public void testWhenProxiedWithEmptyProxyContextPath() {
when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https");
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host");
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443");
when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/");
when(request.getContextPath()).thenReturn("/nifi-api");
when(request.getServletPath()).thenReturn("/access/saml/metadata");
final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
assertEquals("https://proxy-host:443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
}
}