NIFI-8639: Add incoming flowfile to ConnectWebSocket processor to configure custom headers and dynamic URL through flowfile attributes in JettyWebSocketClient service.

This resolves #5130.

Signed-off-by: Tamas Palfy <tamas.bertalan.palfy@gmail.com>
This commit is contained in:
Lehel Boér 2021-06-21 15:39:37 +02:00 committed by Tamas Palfy
parent af0f3403a5
commit b99e7fc560
11 changed files with 359 additions and 40 deletions

View File

@ -52,5 +52,11 @@
<version>1.15.0-SNAPSHOT</version> <version>1.15.0-SNAPSHOT</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-websocket-services-jetty</artifactId>
<version>1.14.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -43,6 +43,7 @@ import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
@ -89,6 +90,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
logger = getLogger(); logger = getLogger();
} }
@FunctionalInterface
public interface WebSocketFunction { public interface WebSocketFunction {
void execute(final WebSocketService webSocketService) throws IOException, WebSocketConfigurationException; void execute(final WebSocketService webSocketService) throws IOException, WebSocketConfigurationException;
} }
@ -118,12 +120,24 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
// @OnScheduled can not report error messages well on bulletin since it's an async method. // @OnScheduled can not report error messages well on bulletin since it's an async method.
// So, let's do it in onTrigger(). // So, let's do it in onTrigger().
public void onWebSocketServiceReady(final WebSocketService webSocketService) throws IOException { public void onWebSocketServiceReady(final WebSocketService webSocketService, final ProcessContext context) throws IOException {
if (webSocketService instanceof WebSocketClientService) { if (webSocketService instanceof WebSocketClientService) {
// If it's a ws client, then connect to the remote here. // If it's a ws client, then connect to the remote here.
// Otherwise, ws server is already started at WebSocketServerService // Otherwise, ws server is already started at WebSocketServerService
((WebSocketClientService) webSocketService).connect(endpointId); WebSocketClientService webSocketClientService = (WebSocketClientService) webSocketService;
if (context.hasIncomingConnection()) {
final ProcessSession session = processSessionFactory.createSession();
final FlowFile flowFile = session.get();
final Map<String, String> attributes = flowFile.getAttributes();
try {
webSocketClientService.connect(endpointId, attributes);
} finally {
session.remove(flowFile);
session.commitAsync();
}
} else {
webSocketClientService.connect(endpointId);
}
} }
} }
@ -137,6 +151,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
} }
protected abstract WebSocketService getWebSocketService(final ProcessContext context); protected abstract WebSocketService getWebSocketService(final ProcessContext context);
protected abstract String getEndpointId(final ProcessContext context); protected abstract String getEndpointId(final ProcessContext context);
protected boolean isProcessorRegisteredToService() { protected boolean isProcessorRegisteredToService() {
@ -146,7 +161,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
} }
@OnStopped @OnStopped
public void onStopped(final ProcessContext context) throws IOException { public void onStopped(final ProcessContext context) {
deregister(); deregister();
} }
@ -165,20 +180,29 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
} }
@Override @Override
public final void onTrigger(final ProcessContext context, final ProcessSessionFactory sessionFactory) throws ProcessException { public final void onTrigger(final ProcessContext context, final ProcessSessionFactory sessionFactory) {
if (processSessionFactory == null) { if (processSessionFactory == null) {
processSessionFactory = sessionFactory; processSessionFactory = sessionFactory;
} }
if (!isProcessorRegisteredToService()) { if (!isProcessorRegisteredToService()) {
try { try {
registerProcessorToService(context, webSocketService -> onWebSocketServiceReady(webSocketService)); registerProcessorToService(context, webSocketService -> onWebSocketServiceReady(webSocketService, context));
} catch (IOException | WebSocketConfigurationException e) { } catch (IOException | WebSocketConfigurationException e) {
// Deregister processor if it failed so that it can retry next onTrigger. // Deregister processor if it failed so that it can retry next onTrigger.
deregister(); deregister();
context.yield(); context.yield();
throw new ProcessException("Failed to register processor to WebSocket service due to: " + e, e); throw new ProcessException("Failed to register processor to WebSocket service due to: " + e, e);
} }
} else {
try {
onWebSocketServiceReady(webSocketService, context);
} catch (IOException e) {
deregister();
context.yield();
throw new ProcessException("Failed to renew session and connect to WebSocket service due to: " + e, e);
}
} }
context.yield();//nothing really to do here since handling WebSocket messages is done at ControllerService. context.yield();//nothing really to do here since handling WebSocket messages is done at ControllerService.
@ -206,9 +230,9 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
final byte[] payload = incomingMessage.getPayload(); final byte[] payload = incomingMessage.getPayload();
if (payload != null) { if (payload != null) {
messageFlowFile = session.write(messageFlowFile, out -> { messageFlowFile = session.write(messageFlowFile, out ->
out.write(payload, incomingMessage.getOffset(), incomingMessage.getLength()); out.write(payload, incomingMessage.getOffset(), incomingMessage.getLength())
}); );
} }
session.getProvenanceReporter().receive(messageFlowFile, getTransitUri(sessionInfo)); session.getProvenanceReporter().receive(messageFlowFile, getTransitUri(sessionInfo));
@ -216,7 +240,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
if (incomingMessage instanceof WebSocketConnectedMessage) { if (incomingMessage instanceof WebSocketConnectedMessage) {
session.transfer(messageFlowFile, REL_CONNECTED); session.transfer(messageFlowFile, REL_CONNECTED);
} else { } else {
switch (messageType) { switch (Objects.requireNonNull(messageType)) {
case TEXT: case TEXT:
session.transfer(messageFlowFile, REL_MESSAGE_TEXT); session.transfer(messageFlowFile, REL_MESSAGE_TEXT);
break; break;
@ -233,6 +257,6 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF
} }
} }
abstract protected String getTransitUri(final WebSocketSessionInfo sessionInfo); protected abstract String getTransitUri(final WebSocketSessionInfo sessionInfo);
} }

View File

@ -44,7 +44,7 @@ import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID;
@Tags({"subscribe", "WebSocket", "consume", "listen"}) @Tags({"subscribe", "WebSocket", "consume", "listen"})
@InputRequirement(InputRequirement.Requirement.INPUT_FORBIDDEN) @InputRequirement(InputRequirement.Requirement.INPUT_ALLOWED)
@TriggerSerially @TriggerSerially
@CapabilityDescription("Acts as a WebSocket client endpoint to interact with a remote WebSocket server." + @CapabilityDescription("Acts as a WebSocket client endpoint to interact with a remote WebSocket server." +
" FlowFiles are transferred to downstream relationships according to received message types" + " FlowFiles are transferred to downstream relationships according to received message types" +

View File

@ -0,0 +1,47 @@
<!DOCTYPE html>
<html lang="en">
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<head>
<meta charset="utf-8"/>
<title>ConnectWebSocket</title>
<link rel="stylesheet" href="../../../../../css/component-usage.css" type="text/css"/>
</head>
<body>
<h2>Summary</h2>
<p>
This processor acts as a WebSocket client endpoint to interact with a remote WebSocket server.
It is capable of receiving messages from a websocket server and it transfers them to downstream relationships
according to the received message types.
</p>
<p>
The processor may have an incoming relationship, in which case flowfile attributes are passed down to its WebSocket Client Service.
This can be used to fine-tune the connection configuration (url and headers for example). For example "dynamic.url =
currentValue" flowfile attribute can be referenced in the WebSocket Client Service with the ${dynamic.url} expression.
</p>
<p>
You can define custom websocket headers in the incoming flowfile as additional attributes. The attribute key
shall start with "header." and continue with they header key. For example: "header.Authorization". The attribute
value will be the corresponding header value.
<ol>
<li>header.Autorization | Basic base64UserNamePassWord</li>
<li>header.Content-Type | application, audio, example</li>
</ol>
<p>
For multiple header values provide a comma separated list.
</p>
</body>
</html>

View File

@ -20,6 +20,8 @@ import org.apache.nifi.processor.ProcessSessionFactory;
import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.Relationship;
import org.apache.nifi.provenance.ProvenanceEventRecord; import org.apache.nifi.provenance.ProvenanceEventRecord;
import org.apache.nifi.provenance.ProvenanceEventType; import org.apache.nifi.provenance.ProvenanceEventType;
import org.apache.nifi.remote.io.socket.NetworkUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.MockProcessSession; import org.apache.nifi.util.MockProcessSession;
import org.apache.nifi.util.SharedSessionState; import org.apache.nifi.util.SharedSessionState;
@ -29,9 +31,13 @@ import org.apache.nifi.websocket.AbstractWebSocketSession;
import org.apache.nifi.websocket.WebSocketClientService; import org.apache.nifi.websocket.WebSocketClientService;
import org.apache.nifi.websocket.WebSocketMessage; import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketSession; import org.apache.nifi.websocket.WebSocketSession;
import org.apache.nifi.websocket.jetty.JettyWebSocketClient;
import org.apache.nifi.websocket.jetty.JettyWebSocketServer;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -51,6 +57,8 @@ public class TestConnectWebSocket extends TestListenWebSocket {
@Test @Test
public void testSuccess() throws Exception { public void testSuccess() throws Exception {
final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class); final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class);
runner.setIncomingConnection(false);
final ConnectWebSocket processor = (ConnectWebSocket) runner.getProcessor(); final ConnectWebSocket processor = (ConnectWebSocket) runner.getProcessor();
final SharedSessionState sharedSessionState = new SharedSessionState(processor, new AtomicLong(0)); final SharedSessionState sharedSessionState = new SharedSessionState(processor, new AtomicLong(0));
@ -121,4 +129,63 @@ public class TestConnectWebSocket extends TestListenWebSocket {
assertTrue(provenanceEvents.stream().allMatch(event -> ProvenanceEventType.RECEIVE.equals(event.getEventType()))); assertTrue(provenanceEvents.stream().allMatch(event -> ProvenanceEventType.RECEIVE.equals(event.getEventType())));
} }
@Test
public void testDynamicUrlsParsedFromFlowFileAndAbleToConnect() throws InitializationException {
// Start websocket server
final int port = NetworkUtils.availablePort();
TestRunner webSocketListener = getListenWebSocket(port);
webSocketListener.run(1, false);
final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class);
final String serviceId = "ws-service";
final String endpointId = "client-1";
Map<String, String> attributes = new HashMap<>();
attributes.put("dynamicUrlPart", "test");
MockFlowFile flowFile = new MockFlowFile(1L);
flowFile.putAttributes(attributes);
runner.enqueue(flowFile);
attributes.put("dynamicUrlPart", "test2");
MockFlowFile flowFileWithWrongUrl = new MockFlowFile(2L);
flowFileWithWrongUrl.putAttributes(attributes);
runner.enqueue(flowFileWithWrongUrl);
JettyWebSocketClient service = new JettyWebSocketClient();
runner.addControllerService(serviceId, service);
runner.setProperty(service, JettyWebSocketClient.WS_URI, String.format("ws://localhost:%s/${dynamicUrlPart}", port));
runner.enableControllerService(service);
runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_SERVICE, serviceId);
runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_ID, endpointId);
runner.run(1, false);
final List<MockFlowFile> flowFilesForRelationship = runner.getFlowFilesForRelationship(ConnectWebSocket.REL_CONNECTED);
assertEquals(1, flowFilesForRelationship.size());
final AssertionError assertionError = Assert.assertThrows(AssertionError.class, () -> runner.run(1));
assertTrue(assertionError.getCause().getLocalizedMessage().contains("Failed to renew session and connect to WebSocket service"));
runner.stop();
webSocketListener.stop();
}
private TestRunner getListenWebSocket(final int port) throws InitializationException {
final TestRunner runner = TestRunners.newTestRunner(ListenWebSocket.class);
final String serviceId = "ws-server-service";
JettyWebSocketServer service = new JettyWebSocketServer();
runner.addControllerService(serviceId, service);
runner.setProperty(service, JettyWebSocketServer.LISTEN_PORT, String.valueOf(port));
runner.enableControllerService(service);
runner.setProperty(ListenWebSocket.PROP_WEBSOCKET_SERVER_SERVICE, serviceId);
runner.setProperty(ListenWebSocket.PROP_SERVER_URL_PATH, "/test");
return runner;
}
} }

View File

@ -20,6 +20,7 @@ import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.controller.ConfigurationContext; import org.apache.nifi.controller.ConfigurationContext;
import java.io.IOException; import java.io.IOException;
import java.util.Map;
/** /**
* Control a WebSocket client instance. * Control a WebSocket client instance.
@ -34,6 +35,10 @@ public interface WebSocketClientService extends WebSocketService {
void connect(final String clientId) throws IOException; void connect(final String clientId) throws IOException;
default void connect(final String clientId, final Map<String, String> flowFileAttributes) throws IOException {
connect(clientId);
}
String getTargetUri(); String getTargetUri();
} }

View File

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dto;
import java.util.Map;
public class SessionInfo {
private final String sessionId;
private final Map<String, String> flowFileAttributes;
public SessionInfo(final String sessionId, final Map<String, String> flowFileAttributes) {
this.sessionId = sessionId;
this.flowFileAttributes = flowFileAttributes;
}
public String getSessionId() {
return sessionId;
}
public Map<String, String> getFlowFileAttributes() {
return flowFileAttributes;
}
}

View File

@ -16,6 +16,7 @@
*/ */
package org.apache.nifi.websocket.jetty; package org.apache.nifi.websocket.jetty;
import dto.SessionInfo;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnDisabled; import org.apache.nifi.annotation.lifecycle.OnDisabled;
@ -27,6 +28,7 @@ import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.controller.ConfigurationContext; import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.ssl.SSLContextService; import org.apache.nifi.ssl.SSLContextService;
import org.apache.nifi.util.StringUtils; import org.apache.nifi.util.StringUtils;
@ -40,14 +42,17 @@ import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import util.HeaderMapExtractor;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -68,7 +73,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
.displayName("WebSocket URI") .displayName("WebSocket URI")
.description("The WebSocket URI this client connects to.") .description("The WebSocket URI this client connects to.")
.required(true) .required(true)
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
.addValidator(StandardValidators.URI_VALIDATOR) .addValidator(StandardValidators.URI_VALIDATOR)
.addValidator((subject, input, context) -> { .addValidator((subject, input, context) -> {
final ValidationResult.Builder result = new ValidationResult.Builder() final ValidationResult.Builder result = new ValidationResult.Builder()
@ -161,8 +166,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
private static final List<PropertyDescriptor> properties; private static final List<PropertyDescriptor> properties;
static { static {
final List<PropertyDescriptor> props = new ArrayList<>(); final List<PropertyDescriptor> props = new ArrayList<>(getAbstractPropertyDescriptors());
props.addAll(getAbstractPropertyDescriptors());
props.add(WS_URI); props.add(WS_URI);
props.add(SSL_CONTEXT); props.add(SSL_CONTEXT);
props.add(CONNECTION_TIMEOUT); props.add(CONNECTION_TIMEOUT);
@ -176,12 +180,14 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
properties = Collections.unmodifiableList(props); properties = Collections.unmodifiableList(props);
} }
private final Map<String, SessionInfo> activeSessions = new ConcurrentHashMap<>();
private final ReentrantLock connectionLock = new ReentrantLock();
private WebSocketClient client; private WebSocketClient client;
private URI webSocketUri; private URI webSocketUri;
private String authorizationHeader; private String authorizationHeader;
private long connectionTimeoutMillis; private long connectionTimeoutMillis;
private volatile ScheduledExecutorService sessionMaintenanceScheduler; private volatile ScheduledExecutorService sessionMaintenanceScheduler;
private final ReentrantLock connectionLock = new ReentrantLock(); private ConfigurationContext configurationContext;
@Override @Override
protected List<PropertyDescriptor> getSupportedPropertyDescriptors() { protected List<PropertyDescriptor> getSupportedPropertyDescriptors() {
@ -191,7 +197,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
@OnEnabled @OnEnabled
@Override @Override
public void startClient(final ConfigurationContext context) throws Exception { public void startClient(final ConfigurationContext context) throws Exception {
configurationContext = context;
final SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class); final SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class);
SslContextFactory sslContextFactory = null; SslContextFactory sslContextFactory = null;
if (sslService != null) { if (sslService != null) {
@ -227,8 +233,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
client.start(); client.start();
activeSessions.clear(); activeSessions.clear();
webSocketUri = new URI(context.getProperty(WS_URI).evaluateAttributeExpressions(new HashMap<>()).getValue());
webSocketUri = new URI(context.getProperty(WS_URI).evaluateAttributeExpressions().getValue());
connectionTimeoutMillis = context.getProperty(CONNECTION_TIMEOUT).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS); connectionTimeoutMillis = context.getProperty(CONNECTION_TIMEOUT).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS);
final Long sessionMaintenanceInterval = context.getProperty(SESSION_MAINTENANCE_INTERVAL).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS); final Long sessionMaintenanceInterval = context.getProperty(SESSION_MAINTENANCE_INTERVAL).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS);
@ -281,10 +286,20 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
@Override @Override
public void connect(final String clientId) throws IOException { public void connect(final String clientId) throws IOException {
connect(clientId, null); connect(clientId, null, Collections.emptyMap());
} }
private void connect(final String clientId, String sessionId) throws IOException { @Override
public void connect(final String clientId, final Map<String, String> flowFileAttributes) throws IOException {
connect(clientId, null, flowFileAttributes);
}
private void connect(final String clientId, final String sessionId, final Map<String, String> flowFileAttributes) throws IOException {
try {
webSocketUri = new URI(configurationContext.getProperty(WS_URI).evaluateAttributeExpressions(flowFileAttributes).getValue());
} catch (URISyntaxException e) {
throw new ProcessException("Could not create websocket URI", e);
}
connectionLock.lock(); connectionLock.lock();
@ -299,11 +314,15 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
listener.setSessionId(sessionId); listener.setSessionId(sessionId);
final ClientUpgradeRequest request = new ClientUpgradeRequest(); final ClientUpgradeRequest request = new ClientUpgradeRequest();
if (!flowFileAttributes.isEmpty()) {
request.setHeaders(HeaderMapExtractor.getHeaderMap(flowFileAttributes));
}
if (!StringUtils.isEmpty(authorizationHeader)) { if (!StringUtils.isEmpty(authorizationHeader)) {
request.setHeader(HttpHeader.AUTHORIZATION.asString(), authorizationHeader); request.setHeader(HttpHeader.AUTHORIZATION.asString(), authorizationHeader);
} }
final Future<Session> connect = client.connect(listener, webSocketUri, request); final Future<Session> connect = client.connect(listener, webSocketUri, request);
getLogger().info("Connecting to : {}", new Object[]{webSocketUri}); getLogger().info("Connecting to : {}", webSocketUri);
final Session session; final Session session;
try { try {
@ -311,8 +330,8 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
} catch (Exception e) { } catch (Exception e) {
throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e); throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e);
} }
getLogger().info("Connected, session={}", new Object[]{session}); getLogger().info("Connected, session={}", session);
activeSessions.put(clientId, listener.getSessionId()); activeSessions.put(clientId, new SessionInfo(listener.getSessionId(), flowFileAttributes));
} finally { } finally {
connectionLock.unlock(); connectionLock.unlock();
@ -320,8 +339,6 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
} }
private Map<String, String> activeSessions = new ConcurrentHashMap<>();
void maintainSessions() throws Exception { void maintainSessions() throws Exception {
if (client == null) { if (client == null) {
return; return;
@ -338,19 +355,19 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
router = routers.getRouterOrFail(clientId); router = routers.getRouterOrFail(clientId);
} catch (final WebSocketConfigurationException e) { } catch (final WebSocketConfigurationException e) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("The clientId {} is no longer active. Discarding the clientId.", new Object[]{clientId}); logger.debug("The clientId {} is no longer active. Discarding the clientId.", clientId);
} }
activeSessions.remove(clientId); activeSessions.remove(clientId);
continue; continue;
} }
final String sessionId = activeSessions.get(clientId); final SessionInfo sessionInfo = activeSessions.get(clientId);
// If this session is still alive, do nothing. // If this session is still alive, do nothing.
if (!router.containsSession(sessionId)) { if (!router.containsSession(sessionInfo.getSessionId())) {
// This session is no longer active, reconnect it. // This session is no longer active, reconnect it.
// If it fails, the sessionId will remain in activeSessions, and retries later. // If it fails, the sessionId will remain in activeSessions, and retries later.
// This reconnect attempt is continued until user explicitly stops a processor or this controller service. // This reconnect attempt is continued until user explicitly stops a processor or this controller service.
connect(clientId, sessionId); connect(clientId, sessionInfo.getSessionId(), sessionInfo.getFlowFileAttributes());
} }
} }
} finally { } finally {
@ -358,7 +375,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
} }
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Session maintenance completed. activeSessions={}", new Object[]{activeSessions}); logger.debug("Session maintenance completed. activeSessions={}", activeSessions);
} }
} }

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package util;
import org.apache.nifi.util.StringUtils;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public final class HeaderMapExtractor {
private HeaderMapExtractor(){
// Utility class, not meant to be instantiated.
}
public static final String HEADER_PREFIX = "header.";
public static Map<String, List<String>> getHeaderMap(final Map<String, String> flowFileAttributes) {
return flowFileAttributes.entrySet().stream()
.filter(entry -> entry.getKey().startsWith(HEADER_PREFIX))
.filter(entry -> StringUtils.isNotBlank(entry.getValue()))
.map(entry -> new AbstractMap.SimpleImmutableEntry<>(StringUtils.substringAfter(entry.getKey(), HEADER_PREFIX), entry.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, HeaderMapExtractor::headerValueMapper));
}
private static List<String> headerValueMapper(Map.Entry<String, String> entry) {
return Arrays.stream(entry.getValue().split(",")).map(String::trim).collect(Collectors.toList());
}
}

View File

@ -30,6 +30,7 @@ import org.mockito.invocation.InvocationOnMock;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -234,7 +235,7 @@ public class ITJettyWebSocketCommunication {
clientService.registerProcessor(clientId, clientProcessor); clientService.registerProcessor(clientId, clientProcessor);
clientService.connect(clientId); clientService.connect(clientId, Collections.emptyMap());
assertTrue("WebSocket client should be able to fire connected event.", clientConnectedServer.await(5, TimeUnit.SECONDS)); assertTrue("WebSocket client should be able to fire connected event.", clientConnectedServer.await(5, TimeUnit.SECONDS));
assertTrue("WebSocket server should be able to fire connected event.", serverIsConnectedByClient.await(5, TimeUnit.SECONDS)); assertTrue("WebSocket server should be able to fire connected event.", serverIsConnectedByClient.await(5, TimeUnit.SECONDS));

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket.util;
import org.junit.Test;
import util.HeaderMapExtractor;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class HeaderMapExtractorTest {
@Test
public void testMapExtractor() {
// GIVEN
final Map<String, String> attributes = new HashMap<>();
attributes.put("header.AUTHORIZATION", "AUTH_VALUE");
attributes.put("header.MULTI_HEADER_KEY", "FIRST, SECOND ,THIRD");
attributes.put("header.dots.dots.dots.DOTS_KEY", "DOTS_VALUE");
attributes.put("something.else.SOMETHING_ELSE_KEY", "SOMETHING_ELSE_VALUE");
attributes.put("header.EMPTY_VALUE_KEY", "");
attributes.put("headerButNotReally.UNRECOGNIZED", "NOT_A_HEADER_VALUE");
final Map<String, List<String>> expected = new HashMap<>();
expected.put("AUTHORIZATION", Collections.singletonList("AUTH_VALUE"));
expected.put("MULTI_HEADER_KEY", Arrays.asList("FIRST", "SECOND", "THIRD"));
expected.put("dots.dots.dots.DOTS_KEY", Collections.singletonList("DOTS_VALUE"));
// WHEN
final Map<String, List<String>> actual = HeaderMapExtractor.getHeaderMap(attributes);
// THEN
assertEquals(expected, actual);
assertEquals(expected.size(), actual.size());
for (Map.Entry<String, List<String>> entry : actual.entrySet()) {
assertTrue(expected.containsKey(entry.getKey()));
final List<String> actualHeaderValues = entry.getValue();
final List<String> expectedHeaderValues = expected.get(entry.getKey());
for (int i = 0; i < actualHeaderValues.size(); i++) {
assertEquals(expectedHeaderValues.get(i), actualHeaderValues.get(i));
}
}
}
}