NIFI-13195 Corrected Replicated Request Header Removal for HTTP/2 (#8789)

- Added findHeaderName method to look for header names regardless of casing

This closes #8789
This commit is contained in:
David Handermann 2024-05-09 17:02:44 -05:00 committed by GitHub
parent 1592b98298
commit e005d5f8c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 54 additions and 23 deletions

View File

@ -69,6 +69,7 @@ import java.util.List;
import java.util.LongSummaryStatistics;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@ -93,6 +94,9 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
private static final Logger logger = LoggerFactory.getLogger(ThreadPoolRequestReplicator.class);
private static final Pattern SNIPPET_URI_PATTERN = Pattern.compile("/nifi-api/snippets/[a-f0-9\\-]{36}");
private static final String COOKIE_HEADER = "Cookie";
private static final String HOST_HEADER = "Host";
private final int maxConcurrentRequests; // maximum number of concurrent requests
private final HttpResponseMapper responseMapper;
private final EventReporter eventReporter;
@ -100,8 +104,8 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
private final ClusterCoordinator clusterCoordinator;
private final NiFiProperties nifiProperties;
private ThreadPoolExecutor executorService;
private ScheduledExecutorService maintenanceExecutor;
private final ThreadPoolExecutor executorService;
private final ScheduledExecutorService maintenanceExecutor;
private final ConcurrentMap<String, StandardAsyncClusterResponse> responseMap = new ConcurrentHashMap<>();
private final ConcurrentMap<NodeIdentifier, AtomicInteger> sequentialLongRequestCounts = new ConcurrentHashMap<>();
@ -110,7 +114,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
private final Lock readLock = rwLock.readLock();
private final Lock writeLock = rwLock.writeLock();
private HttpReplicationClient httpClient;
private final HttpReplicationClient httpClient;
/**
@ -151,17 +155,14 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
executorService = new ThreadPoolExecutor(maxPoolSize, maxPoolSize, 5, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), threadFactory);
executorService.allowCoreThreadTimeOut(true);
maintenanceExecutor = Executors.newScheduledThreadPool(1, new ThreadFactory() {
@Override
public Thread newThread(final Runnable r) {
final Thread t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
t.setName(ThreadPoolRequestReplicator.class.getSimpleName() + " Maintenance Thread");
return t;
}
maintenanceExecutor = Executors.newScheduledThreadPool(1, r -> {
final Thread t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
t.setName(ThreadPoolRequestReplicator.class.getSimpleName() + " Maintenance Thread");
return t;
});
maintenanceExecutor.scheduleWithFixedDelay(() -> purgeExpiredRequests(), 1, 1, TimeUnit.SECONDS);
maintenanceExecutor.scheduleWithFixedDelay(this::purgeExpiredRequests, 1, 1, TimeUnit.SECONDS);
}
@Override
@ -187,7 +188,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
final List<NodeIdentifier> connecting = stateMap.get(NodeConnectionState.CONNECTING);
if (connecting != null && !connecting.isEmpty()) {
if (connecting.size() == 1) {
throw new ConnectingNodeMutableRequestException("Node " + connecting.iterator().next() + " is currently connecting");
throw new ConnectingNodeMutableRequestException("Node " + connecting.getFirst() + " is currently connecting");
} else {
throw new ConnectingNodeMutableRequestException(connecting.size() + " Nodes are currently connecting");
}
@ -248,7 +249,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
// remove the access token if present, since the user is already authenticated... authorization
// will happen when the request is replicated using the proxy chain above
headers.remove(SecurityHeader.AUTHORIZATION.getHeader());
removeHeader(headers, SecurityHeader.AUTHORIZATION.getHeader());
// if knox sso cookie name is set, remove any authentication cookie since this user is already authenticated
// and will be included in the proxied entities chain above... authorization will happen when the
@ -258,7 +259,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
removeCookie(headers, SecurityCookieName.REQUEST_TOKEN.getName());
// remove the host header
headers.remove("Host");
removeHeader(headers, HOST_HEADER);
}
@Override
@ -471,7 +472,7 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
synchronized (monitor) {
monitor.notify();
}
logger.debug("Notified monitor {} because request {} {} has failed with Throwable {}", monitor, method, uri, t);
logger.debug("Notified monitor {} because request {} {} has failed", monitor, method, uri, t);
}
if (response != null) {
@ -895,19 +896,49 @@ public class ThreadPoolRequestReplicator implements RequestReplicator {
return responseMap.size();
}
private void removeCookie(Map<String, String> headers, final String cookieName) {
if (headers.containsKey("Cookie") && StringUtils.isNotBlank(cookieName)) {
final String rawCookies = headers.get("Cookie");
private void removeCookie(final Map<String, String> headers, final String cookieName) {
final Optional<String> cookieHeaderNameFound = findHeaderName(headers, COOKIE_HEADER);
if (cookieHeaderNameFound.isPresent()) {
final String cookieHeaderName = cookieHeaderNameFound.get();
final String rawCookies = headers.get(cookieHeaderName);
final String[] rawCookieParts = rawCookies.split(";");
final Set<String> filteredCookieParts = Stream.of(rawCookieParts).map(String::trim).filter(cookie -> !cookie.startsWith(cookieName + "=")).collect(Collectors.toSet());
// if that was the only cookie, remove it
if (filteredCookieParts.isEmpty()) {
headers.remove("Cookie");
headers.remove(cookieHeaderName);
} else {
// otherwise rebuild the cookies without the knox token
headers.put("Cookie", StringUtils.join(filteredCookieParts, "; "));
final String filteredCookies = StringUtils.join(filteredCookieParts, "; ");
headers.put(cookieHeaderName, filteredCookies);
}
}
}
private void removeHeader(final Map<String, String> headers, final String headerNameSearch) {
final Optional<String> headerNameFound = findHeaderName(headers, headerNameSearch);
headerNameFound.ifPresent(headers::remove);
}
/**
* Find HTTP Header name in map regardless of case since HTTP/1.1 capitalizes headers but HTTP/2 returns lowercased headers
*
* @param headers Map of header name to value
* @param headerName Header name to be found
* @return Optional match with header name from map of headers
*/
private Optional<String> findHeaderName(final Map<String, String> headers, final String headerName) {
final Optional<String> headerNameFound;
if (headerName == null || headerName.isBlank()) {
headerNameFound = Optional.empty();
} else {
headerNameFound = headers.keySet()
.stream()
.filter(headerName::equalsIgnoreCase)
.findFirst();
}
return headerNameFound;
}
}