mirror of https://github.com/apache/nifi.git
NIFI-8218 This closes #4816. Use proxy headers when available when getting request values while processing SAML responses
Signed-off-by: Joe Witt <joewitt@apache.org>
This commit is contained in:
parent
d5d520764d
commit
1d82fb8e01
|
@ -16,18 +16,6 @@
|
|||
*/
|
||||
package org.apache.nifi.web.util;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
import java.util.stream.Stream;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.ws.rs.client.Client;
|
||||
import javax.ws.rs.client.ClientBuilder;
|
||||
import javax.ws.rs.core.UriBuilderException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.http.conn.ssl.DefaultHostnameVerifier;
|
||||
import org.glassfish.jersey.client.ClientConfig;
|
||||
|
@ -35,6 +23,19 @@ import org.glassfish.jersey.jackson.internal.jackson.jaxrs.json.JacksonJaxbJsonP
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.ws.rs.client.Client;
|
||||
import javax.ws.rs.client.ClientBuilder;
|
||||
import javax.ws.rs.core.UriBuilderException;
|
||||
import java.net.URI;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* Common utilities related to web development.
|
||||
*/
|
||||
|
@ -44,9 +45,17 @@ public final class WebUtils {
|
|||
|
||||
final static ReadWriteLock lock = new ReentrantReadWriteLock();
|
||||
|
||||
private static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
|
||||
private static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
|
||||
private static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
|
||||
public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme";
|
||||
public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost";
|
||||
public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort";
|
||||
|
||||
public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto";
|
||||
public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host";
|
||||
public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port";
|
||||
|
||||
public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
|
||||
public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
|
||||
public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
|
||||
|
||||
private WebUtils() {
|
||||
}
|
||||
|
@ -248,4 +257,118 @@ public final class WebUtils {
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value for the first key discovered when inspecting the current request. Will
|
||||
* return null if there are no keys specified or if none of the specified keys are found.
|
||||
*
|
||||
* @param httpServletRequest request
|
||||
* @param keys http header keys
|
||||
* @return the value for the first key found, or null if no matching keys found
|
||||
*/
|
||||
public static String getFirstHeaderValue(final HttpServletRequest httpServletRequest, final String... keys) {
|
||||
if (keys == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (final String key : keys) {
|
||||
final String value = httpServletRequest.getHeader(key);
|
||||
|
||||
// if we found an entry for this key, return the value
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
// unable to find any matching keys
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the scheme based on considering proxy related headers first and then falling back to the scheme of the servlet request.
|
||||
*
|
||||
* @param httpServletRequest the request
|
||||
* @return the determined scheme
|
||||
*/
|
||||
public static String determineProxiedScheme(final HttpServletRequest httpServletRequest) {
|
||||
final String schemeHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_SCHEME_HTTP_HEADER, FORWARDED_PROTO_HTTP_HEADER);
|
||||
return StringUtils.isBlank(schemeHeaderValue) ? httpServletRequest.getScheme() : schemeHeaderValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the host based on considering proxy related headers first and falling back to the host of the servlet request.
|
||||
*
|
||||
* @param httpServletRequest the request
|
||||
* @return the determined host
|
||||
*/
|
||||
public static String determineProxiedHost(final HttpServletRequest httpServletRequest) {
|
||||
final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
|
||||
final String proxiedHost = determineProxiedHost(hostHeaderValue);
|
||||
return StringUtils.isBlank(proxiedHost) ? httpServletRequest.getServerName() : proxiedHost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the host from the given header. The header value is intended to come from a header like X-ProxyHost or X-Forwarded-Host.
|
||||
*
|
||||
* @param hostHeaderValue the header value
|
||||
* @return the determined host, or null if a host can't be determined
|
||||
*/
|
||||
public static String determineProxiedHost(final String hostHeaderValue) {
|
||||
final String host;
|
||||
// check for a port in the proxied host header
|
||||
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
|
||||
if (hostSplits.length >= 1 && hostSplits.length <= 2) {
|
||||
// zero or one occurrence of ':', this is an IPv4 address
|
||||
// strip off the port by reassigning host the 0th split
|
||||
host = hostSplits[0];
|
||||
} else if (hostSplits.length == 0) {
|
||||
// hostHeaderValue passed in was null, no splits
|
||||
host = null;
|
||||
} else {
|
||||
// hostHeaderValue has more than one occurrence of ":", IPv6 address
|
||||
host = hostHeaderValue;
|
||||
}
|
||||
return host;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the port based on first considering proxy related headers and falling back to the port of the servlet request.
|
||||
*
|
||||
* @param httpServletRequest the request
|
||||
* @return the determined port
|
||||
*/
|
||||
public static String determineProxiedPort(final HttpServletRequest httpServletRequest) {
|
||||
final String hostHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
|
||||
final String portHeaderValue = getFirstHeaderValue(httpServletRequest, PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER);
|
||||
|
||||
final String proxiedPort = determineProxiedPort(hostHeaderValue, portHeaderValue);
|
||||
return StringUtils.isBlank(proxiedPort) ? String.valueOf(httpServletRequest.getServerPort()) : proxiedPort;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the port based on the header values. The header values are intended to come from headers like X-ProxyHost/X-ProxyPort
|
||||
* or X-Forwarded-Host/X-Forwarded-Port.
|
||||
*
|
||||
* @param hostHeaderValue the host header value
|
||||
* @param portHeaderValue the host port value
|
||||
* @return the determined port, or null if one can't be determined
|
||||
*/
|
||||
public static String determineProxiedPort(final String hostHeaderValue, final String portHeaderValue) {
|
||||
final String port;
|
||||
// check for a port in the proxied host header
|
||||
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
|
||||
// determine the proxied port
|
||||
final String portFromHostHeader;
|
||||
if (hostSplits.length == 2) {
|
||||
// if the port is specified in the proxied host header, it will be overridden by the
|
||||
// port specified in X-ProxyPort or X-Forwarded-Port
|
||||
portFromHostHeader = hostSplits[1];
|
||||
} else {
|
||||
portFromHostHeader = null;
|
||||
}
|
||||
if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) {
|
||||
logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header"));
|
||||
}
|
||||
port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null);
|
||||
return port;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,8 +39,8 @@ import javax.ws.rs.core.UriBuilderException
|
|||
import java.security.cert.X509Certificate
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
class WebUtilsTest extends GroovyTestCase {
|
||||
private static final Logger logger = LoggerFactory.getLogger(WebUtilsTest.class)
|
||||
class WebUtilsGroovyTest extends GroovyTestCase {
|
||||
private static final Logger logger = LoggerFactory.getLogger(WebUtilsGroovyTest.class)
|
||||
|
||||
static final String PCP_HEADER = "X-ProxyContextPath"
|
||||
static final String FC_HEADER = "X-Forwarded-Context"
|
|
@ -0,0 +1,139 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.util;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class WebUtilsTest {
|
||||
|
||||
@Mock
|
||||
private HttpServletRequest request;
|
||||
|
||||
// -- scheme tests
|
||||
|
||||
@Test
|
||||
public void testDeterminedProxiedSchemeWhenNoHeaders() {
|
||||
when(request.getHeader(any())).thenReturn(null);
|
||||
when(request.getScheme()).thenReturn("https");
|
||||
assertEquals("https", WebUtils.determineProxiedScheme(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeterminedProxiedSchemeWhenXProxySchemeAvailable() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("http");
|
||||
assertEquals("http", WebUtils.determineProxiedScheme(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeterminedProxiedSchemeWhenXForwardedProtoAvailable() {
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_PROTO_HTTP_HEADER))).thenReturn("http");
|
||||
assertEquals("http", WebUtils.determineProxiedScheme(request));
|
||||
}
|
||||
|
||||
// -- host tests
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedHostWhenNoHeaders() {
|
||||
when(request.getHeader(any())).thenReturn(null);
|
||||
when(request.getServerName()).thenReturn("localhost");
|
||||
assertEquals("localhost", WebUtils.determineProxiedHost(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedHostWhenXProxyHostAvailable() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host");
|
||||
assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedHostWhenXProxyHostAvailableWithPort() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:443");
|
||||
assertEquals("x-proxy-host", WebUtils.determineProxiedHost(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedHostWhenXForwardedHostAvailable() {
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host");
|
||||
assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedHostWhenXForwardedHostAvailableWithPort() {
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:443");
|
||||
assertEquals("x-forwarded-host", WebUtils.determineProxiedHost(request));
|
||||
}
|
||||
|
||||
// -- port tests
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenNoHeaders() {
|
||||
when(request.getServerPort()).thenReturn(443);
|
||||
assertEquals("443", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenXProxyPortAvailable() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443");
|
||||
assertEquals("8443", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenPortInXProxyHost() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234");
|
||||
assertEquals("1234", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenXProxyPortOverridesXProxy() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("x-proxy-host:1234");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("8443");
|
||||
assertEquals("8443", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenXForwardedPortAvailable() {
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host");
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443");
|
||||
assertEquals("8443", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenPortInXForwardedHost() {
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234");
|
||||
assertEquals("1234", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDetermineProxiedPortWhenXForwardedPortOverridesXForwardedHost() {
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_HOST_HTTP_HEADER))).thenReturn("x-forwarded-host:1234");
|
||||
when(request.getHeader(eq(WebUtils.FORWARDED_PORT_HTTP_HEADER))).thenReturn("8443");
|
||||
assertEquals("8443", WebUtils.determineProxiedPort(request));
|
||||
}
|
||||
|
||||
}
|
|
@ -16,40 +16,8 @@
|
|||
*/
|
||||
package org.apache.nifi.web.api;
|
||||
|
||||
import static javax.ws.rs.core.Response.Status.NOT_FOUND;
|
||||
import static org.apache.commons.lang3.StringUtils.isEmpty;
|
||||
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME;
|
||||
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Collections;
|
||||
import java.util.Enumeration;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.ws.rs.core.CacheControl;
|
||||
import javax.ws.rs.core.Context;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.MultivaluedHashMap;
|
||||
import javax.ws.rs.core.MultivaluedMap;
|
||||
import javax.ws.rs.core.Response;
|
||||
import javax.ws.rs.core.Response.ResponseBuilder;
|
||||
import javax.ws.rs.core.UriBuilder;
|
||||
import javax.ws.rs.core.UriBuilderException;
|
||||
import javax.ws.rs.core.UriInfo;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.nifi.authorization.AuthorizableLookup;
|
||||
import org.apache.nifi.authorization.AuthorizeAccess;
|
||||
|
@ -92,6 +60,45 @@ import org.apache.nifi.web.util.WebUtils;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.ws.rs.core.CacheControl;
|
||||
import javax.ws.rs.core.Context;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.MultivaluedHashMap;
|
||||
import javax.ws.rs.core.MultivaluedMap;
|
||||
import javax.ws.rs.core.Response;
|
||||
import javax.ws.rs.core.Response.ResponseBuilder;
|
||||
import javax.ws.rs.core.UriBuilder;
|
||||
import javax.ws.rs.core.UriBuilderException;
|
||||
import javax.ws.rs.core.UriInfo;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Collections;
|
||||
import java.util.Enumeration;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static javax.ws.rs.core.Response.Status.NOT_FOUND;
|
||||
import static org.apache.commons.lang3.StringUtils.isEmpty;
|
||||
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME;
|
||||
import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE;
|
||||
import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.FORWARDED_PROTO_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.FORWARDED_HOST_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.FORWARDED_PORT_HTTP_HEADER;
|
||||
|
||||
/**
|
||||
* Base class for controllers.
|
||||
*/
|
||||
|
@ -101,19 +108,6 @@ public abstract class ApplicationResource {
|
|||
public static final String CLIENT_ID = "clientId";
|
||||
public static final String DISCONNECTED_NODE_ACKNOWLEDGED = "disconnectedNodeAcknowledged";
|
||||
|
||||
public static final String PROXY_SCHEME_HTTP_HEADER = "X-ProxyScheme";
|
||||
public static final String PROXY_HOST_HTTP_HEADER = "X-ProxyHost";
|
||||
public static final String PROXY_PORT_HTTP_HEADER = "X-ProxyPort";
|
||||
public static final String PROXY_CONTEXT_PATH_HTTP_HEADER = "X-ProxyContextPath";
|
||||
|
||||
public static final String FORWARDED_PROTO_HTTP_HEADER = "X-Forwarded-Proto";
|
||||
public static final String FORWARDED_HOST_HTTP_HEADER = "X-Forwarded-Host";
|
||||
public static final String FORWARDED_PORT_HTTP_HEADER = "X-Forwarded-Port";
|
||||
public static final String FORWARDED_CONTEXT_HTTP_HEADER = "X-Forwarded-Context";
|
||||
|
||||
// Traefik-specific headers
|
||||
public static final String FORWARDED_PREFIX_HTTP_HEADER = "X-Forwarded-Prefix";
|
||||
|
||||
protected static final String NON_GUARANTEED_ENDPOINT = "Note: This endpoint is subject to change as NiFi and it's REST API evolve.";
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(ApplicationResource.class);
|
||||
|
@ -157,8 +151,8 @@ public abstract class ApplicationResource {
|
|||
final String hostHeaderValue = getFirstHeaderValue(PROXY_HOST_HTTP_HEADER, FORWARDED_HOST_HTTP_HEADER);
|
||||
final String portHeaderValue = getFirstHeaderValue(PROXY_PORT_HTTP_HEADER, FORWARDED_PORT_HTTP_HEADER);
|
||||
|
||||
final String host = determineProxiedHost(hostHeaderValue);
|
||||
final String port = determineProxiedPort(hostHeaderValue, portHeaderValue);
|
||||
final String host = WebUtils.determineProxiedHost(hostHeaderValue);
|
||||
final String port = WebUtils.determineProxiedPort(hostHeaderValue, portHeaderValue);
|
||||
|
||||
// Catch header poisoning
|
||||
String allowedContextPaths = properties.getAllowedContextPaths();
|
||||
|
@ -194,44 +188,6 @@ public abstract class ApplicationResource {
|
|||
return uri;
|
||||
}
|
||||
|
||||
private String determineProxiedHost(String hostHeaderValue) {
|
||||
final String host;
|
||||
// check for a port in the proxied host header
|
||||
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
|
||||
if (hostSplits.length >= 1 && hostSplits.length <= 2) {
|
||||
// zero or one occurrence of ':', this is an IPv4 address
|
||||
// strip off the port by reassigning host the 0th split
|
||||
host = hostSplits[0];
|
||||
} else if (hostSplits.length == 0) {
|
||||
// hostHeaderValue passed in was null, no splits
|
||||
host = null;
|
||||
} else {
|
||||
// hostHeaderValue has more than one occurrence of ":", IPv6 address
|
||||
host = hostHeaderValue;
|
||||
}
|
||||
return host;
|
||||
}
|
||||
|
||||
private String determineProxiedPort(String hostHeaderValue, String portHeaderValue) {
|
||||
final String port;
|
||||
// check for a port in the proxied host header
|
||||
String[] hostSplits = hostHeaderValue == null ? new String[] {} : hostHeaderValue.split(":");
|
||||
// determine the proxied port
|
||||
final String portFromHostHeader;
|
||||
if (hostSplits.length == 2) {
|
||||
// if the port is specified in the proxied host header, it will be overridden by the
|
||||
// port specified in X-ProxyPort or X-Forwarded-Port
|
||||
portFromHostHeader = hostSplits[1];
|
||||
} else {
|
||||
portFromHostHeader = null;
|
||||
}
|
||||
if (StringUtils.isNotBlank(portFromHostHeader) && StringUtils.isNotBlank(portHeaderValue)) {
|
||||
logger.warn(String.format("The proxied host header contained a port, but was overridden by the proxied port header"));
|
||||
}
|
||||
port = StringUtils.isNotBlank(portHeaderValue) ? portHeaderValue : (StringUtils.isNotBlank(portFromHostHeader) ? portFromHostHeader : null);
|
||||
return port;
|
||||
}
|
||||
|
||||
/**
|
||||
* Edit the response headers to indicating no caching.
|
||||
*
|
||||
|
@ -403,21 +359,7 @@ public abstract class ApplicationResource {
|
|||
* @return the value for the first key found
|
||||
*/
|
||||
private String getFirstHeaderValue(final String... keys) {
|
||||
if (keys == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (final String key : keys) {
|
||||
final String value = httpServletRequest.getHeader(key);
|
||||
|
||||
// if we found an entry for this key, return the value
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
// unable to find any matching keys
|
||||
return null;
|
||||
return WebUtils.getFirstHeaderValue(httpServletRequest, keys);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -16,31 +16,6 @@
|
|||
*/
|
||||
package org.apache.nifi.web.api;
|
||||
|
||||
import static org.apache.nifi.web.api.ApplicationResource.PROXY_HOST_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.api.ApplicationResource.PROXY_PORT_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.api.ApplicationResource.PROXY_SCHEME_HTTP_HEADER;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.URL;
|
||||
import javax.servlet.ServletContext;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.ws.rs.core.Response;
|
||||
import javax.ws.rs.core.StreamingOutput;
|
||||
import javax.ws.rs.core.UriBuilder;
|
||||
import javax.ws.rs.core.UriInfo;
|
||||
import org.apache.nifi.authorization.AuthorizableLookup;
|
||||
import org.apache.nifi.authorization.resource.ResourceType;
|
||||
import org.apache.nifi.remote.HttpRemoteSiteListener;
|
||||
|
@ -58,6 +33,32 @@ import org.apache.nifi.web.api.entity.TransactionResultEntity;
|
|||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import javax.servlet.ServletContext;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.ws.rs.core.Response;
|
||||
import javax.ws.rs.core.StreamingOutput;
|
||||
import javax.ws.rs.core.UriBuilder;
|
||||
import javax.ws.rs.core.UriInfo;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.URL;
|
||||
|
||||
import static org.apache.nifi.web.util.WebUtils.PROXY_HOST_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.PROXY_PORT_HTTP_HEADER;
|
||||
import static org.apache.nifi.web.util.WebUtils.PROXY_SCHEME_HTTP_HEADER;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class TestDataTransferResource {
|
||||
|
||||
@BeforeClass
|
||||
|
|
|
@ -17,17 +17,19 @@
|
|||
package org.apache.nifi.web.security.saml.impl;
|
||||
|
||||
import org.apache.nifi.web.security.saml.NiFiSAMLContextProvider;
|
||||
import org.apache.nifi.web.security.saml.impl.http.HttpServletRequestWithParameters;
|
||||
import org.apache.nifi.web.security.saml.impl.http.ProxyAwareHttpServletRequestWrapper;
|
||||
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
|
||||
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
|
||||
import org.opensaml.ws.transport.http.HttpServletResponseAdapter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.security.saml.context.SAMLContextProviderImpl;
|
||||
import org.springframework.security.saml.context.SAMLMessageContext;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletRequestWrapper;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -35,6 +37,8 @@ import java.util.Map;
|
|||
*/
|
||||
public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl implements NiFiSAMLContextProvider {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NiFiSAMLContextProviderImpl.class);
|
||||
|
||||
@Override
|
||||
public SAMLMessageContext getLocalEntity(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters)
|
||||
throws MetadataProviderException {
|
||||
|
@ -60,55 +64,20 @@ public class NiFiSAMLContextProviderImpl extends SAMLContextProviderImpl impleme
|
|||
}
|
||||
|
||||
protected void populateGenericContext(HttpServletRequest request, HttpServletResponse response, Map<String, String> parameters, SAMLMessageContext context) {
|
||||
HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(request, parameters);
|
||||
HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, request.isSecure());
|
||||
HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
|
||||
LOGGER.debug("Populating SAMLContext - request wrapper URL is [{}]", requestWrapper.getRequestURL().toString());
|
||||
|
||||
HttpServletRequestAdapter inTransport = new HttpServletRequestWithParameters(requestWrapper, parameters);
|
||||
HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, requestWrapper.isSecure());
|
||||
|
||||
// Store attribute which cannot be located from InTransport directly
|
||||
request.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, request.getContextPath());
|
||||
requestWrapper.setAttribute(org.springframework.security.saml.SAMLConstants.LOCAL_CONTEXT_PATH, requestWrapper.getContextPath());
|
||||
|
||||
context.setMetadataProvider(metadata);
|
||||
context.setInboundMessageTransport(inTransport);
|
||||
context.setOutboundMessageTransport(outTransport);
|
||||
|
||||
context.setMessageStorage(storageFactory.getMessageStorage(request));
|
||||
context.setMessageStorage(storageFactory.getMessageStorage(requestWrapper));
|
||||
}
|
||||
|
||||
/**
|
||||
* Extends the HttpServletRequestAdapter with a provided set of parameters.
|
||||
*/
|
||||
private static class HttpServletRequestWithParameters extends HttpServletRequestAdapter {
|
||||
|
||||
private final Map<String, String> providedParameters;
|
||||
|
||||
public HttpServletRequestWithParameters(HttpServletRequest request, Map<String,String> providedParameters) {
|
||||
super(request);
|
||||
this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getParameterValue(String name) {
|
||||
String value = super.getParameterValue(name);
|
||||
if (value == null) {
|
||||
value = providedParameters.get(name);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getParameterValues(String name) {
|
||||
List<String> combinedValues = new ArrayList<>();
|
||||
|
||||
List<String> initialValues = super.getParameterValues(name);
|
||||
if (initialValues != null) {
|
||||
combinedValues.addAll(initialValues);
|
||||
}
|
||||
|
||||
String providedValue = providedParameters.get(name);
|
||||
if (providedValue != null) {
|
||||
combinedValues.add(providedValue);
|
||||
}
|
||||
|
||||
return combinedValues;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security.saml.impl.http;
|
||||
|
||||
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Extends the HttpServletRequestAdapter with a provided set of parameters.
|
||||
*/
|
||||
public class HttpServletRequestWithParameters extends HttpServletRequestAdapter {
|
||||
|
||||
private final Map<String, String> providedParameters;
|
||||
|
||||
public HttpServletRequestWithParameters(final HttpServletRequest request, final Map<String, String> providedParameters) {
|
||||
super(request);
|
||||
this.providedParameters = providedParameters == null ? Collections.emptyMap() : providedParameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getParameterValue(final String name) {
|
||||
String value = super.getParameterValue(name);
|
||||
if (value == null) {
|
||||
value = providedParameters.get(name);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getParameterValues(final String name) {
|
||||
List<String> combinedValues = new ArrayList<>();
|
||||
|
||||
List<String> initialValues = super.getParameterValues(name);
|
||||
if (initialValues != null) {
|
||||
combinedValues.addAll(initialValues);
|
||||
}
|
||||
|
||||
String providedValue = providedParameters.get(name);
|
||||
if (providedValue != null) {
|
||||
combinedValues.add(providedValue);
|
||||
}
|
||||
|
||||
return combinedValues;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security.saml.impl.http;
|
||||
|
||||
import org.apache.nifi.web.util.WebUtils;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletRequestWrapper;
|
||||
|
||||
/**
|
||||
* Extension of HttpServletRequestWrapper that respects proxied/forwarded header values for scheme, host, port, and context path.
|
||||
* <p>
|
||||
* If NiFi generates a SAML request using proxied values so that the IDP redirects back through the proxy, then this is needed
|
||||
* so that when Open SAML checks the Destination in the SAML response, it will match with the values here.
|
||||
* <p>
|
||||
* This class is based on SAMLContextProviderLB from spring-security-saml.
|
||||
*/
|
||||
public class ProxyAwareHttpServletRequestWrapper extends HttpServletRequestWrapper {
|
||||
|
||||
private final String scheme;
|
||||
private final String serverName;
|
||||
private final int serverPort;
|
||||
private final String proxyContextPath;
|
||||
private final String contextPath;
|
||||
|
||||
public ProxyAwareHttpServletRequestWrapper(final HttpServletRequest request) {
|
||||
super(request);
|
||||
this.scheme = WebUtils.determineProxiedScheme(request);
|
||||
this.serverName = WebUtils.determineProxiedHost(request);
|
||||
this.serverPort = Integer.valueOf(WebUtils.determineProxiedPort(request));
|
||||
|
||||
final String tempProxyContextPath = WebUtils.normalizeContextPath(WebUtils.determineContextPath(request));
|
||||
this.proxyContextPath = tempProxyContextPath.equals("/") ? "" : tempProxyContextPath;
|
||||
|
||||
this.contextPath = request.getContextPath();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getContextPath() {
|
||||
return contextPath;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getScheme() {
|
||||
return scheme;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getServerName() {
|
||||
return serverName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getServerPort() {
|
||||
return serverPort;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getRequestURI() {
|
||||
StringBuilder sb = new StringBuilder(contextPath);
|
||||
sb.append(getServletPath());
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public StringBuffer getRequestURL() {
|
||||
StringBuffer sb = new StringBuffer();
|
||||
sb.append(scheme).append("://").append(serverName);
|
||||
sb.append(":").append(serverPort);
|
||||
sb.append(proxyContextPath);
|
||||
sb.append(contextPath);
|
||||
sb.append(getServletPath());
|
||||
if (getPathInfo() != null) sb.append(getPathInfo());
|
||||
return sb;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSecure() {
|
||||
return "https".equalsIgnoreCase(scheme);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security.saml.impl.http;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class TestHttpServletRequestWithParameters {
|
||||
|
||||
@Mock
|
||||
private HttpServletRequest request;
|
||||
|
||||
@Test
|
||||
public void testGetParameterValueWhenNoExtraParameters() {
|
||||
final String paramName = "fooParam";
|
||||
final String paramValue = "fooValue";
|
||||
when(request.getParameter(eq(paramName))).thenReturn(paramValue);
|
||||
|
||||
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap());
|
||||
final String result = requestAdapter.getParameterValue(paramName);
|
||||
assertEquals(paramValue, result);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetParameterValueWhenExtraParameters() {
|
||||
final String paramName = "fooParam";
|
||||
final String paramValue = "fooValue";
|
||||
|
||||
final Map<String,String> extraParams = new HashMap<>();
|
||||
extraParams.put(paramName, paramValue);
|
||||
|
||||
when(request.getParameter(any())).thenReturn(null);
|
||||
|
||||
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
|
||||
final String result = requestAdapter.getParameterValue(paramName);
|
||||
assertEquals(paramValue, result);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetParameterValuesWhenNoExtraParameters() {
|
||||
final String paramName = "fooParam";
|
||||
final String paramValue = "fooValue";
|
||||
when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue});
|
||||
|
||||
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, Collections.emptyMap());
|
||||
final List<String> results = requestAdapter.getParameterValues(paramName);
|
||||
assertEquals(1, results.size());
|
||||
assertEquals(paramValue, results.get(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetParameterValuesWhenExtraParameters() {
|
||||
final String paramName = "fooParam";
|
||||
final String paramValue1 = "fooValue1";
|
||||
when(request.getParameterValues(eq(paramName))).thenReturn(new String[] {paramValue1});
|
||||
|
||||
final String paramValue2 = "fooValue2";
|
||||
final Map<String,String> extraParams = new HashMap<>();
|
||||
extraParams.put(paramName, paramValue2);
|
||||
|
||||
final HttpServletRequestAdapter requestAdapter = new HttpServletRequestWithParameters(request, extraParams);
|
||||
final List<String> results = requestAdapter.getParameterValues(paramName);
|
||||
assertEquals(2, results.size());
|
||||
assertTrue(results.contains(paramValue1));
|
||||
assertTrue(results.contains(paramValue2));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.nifi.web.security.saml.impl.http;
|
||||
|
||||
import org.apache.nifi.web.util.WebUtils;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletRequestWrapper;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class TestProxyAwareHttpServletRequestWrapper {
|
||||
|
||||
@Mock
|
||||
private HttpServletRequest request;
|
||||
|
||||
@Test
|
||||
public void testWhenNotProxied() {
|
||||
when(request.getScheme()).thenReturn("https");
|
||||
when(request.getServerName()).thenReturn("localhost");
|
||||
when(request.getServerPort()).thenReturn(8443);
|
||||
when(request.getContextPath()).thenReturn("/nifi-api");
|
||||
when(request.getServletPath()).thenReturn("/access/saml/metadata");
|
||||
when(request.getHeader(any())).thenReturn(null);
|
||||
|
||||
final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
|
||||
assertEquals("https://localhost:8443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWhenProxied() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/proxy-context");
|
||||
when(request.getContextPath()).thenReturn("/nifi-api");
|
||||
when(request.getServletPath()).thenReturn("/access/saml/metadata");
|
||||
|
||||
final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
|
||||
assertEquals("https://proxy-host:443/proxy-context/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWhenProxiedWithEmptyProxyContextPath() {
|
||||
when(request.getHeader(eq(WebUtils.PROXY_SCHEME_HTTP_HEADER))).thenReturn("https");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_HOST_HTTP_HEADER))).thenReturn("proxy-host");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_PORT_HTTP_HEADER))).thenReturn("443");
|
||||
when(request.getHeader(eq(WebUtils.PROXY_CONTEXT_PATH_HTTP_HEADER))).thenReturn("/");
|
||||
when(request.getContextPath()).thenReturn("/nifi-api");
|
||||
when(request.getServletPath()).thenReturn("/access/saml/metadata");
|
||||
|
||||
final HttpServletRequestWrapper requestWrapper = new ProxyAwareHttpServletRequestWrapper(request);
|
||||
assertEquals("https://proxy-host:443/nifi-api/access/saml/metadata", requestWrapper.getRequestURL().toString());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue