diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/pom.xml b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/pom.xml index da7a77cc5e..50b2374502 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/pom.xml +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/pom.xml @@ -52,5 +52,11 @@ 1.15.0-SNAPSHOT test + + org.apache.nifi + nifi-websocket-services-jetty + 1.14.0-SNAPSHOT + test + diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/AbstractWebSocketGatewayProcessor.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/AbstractWebSocketGatewayProcessor.java index cc32c71f19..c516464eb2 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/AbstractWebSocketGatewayProcessor.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/AbstractWebSocketGatewayProcessor.java @@ -43,6 +43,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Objects; import java.util.Set; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID; @@ -89,6 +90,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF logger = getLogger(); } + @FunctionalInterface public interface WebSocketFunction { void execute(final WebSocketService webSocketService) throws IOException, WebSocketConfigurationException; } @@ -118,12 +120,24 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF // @OnScheduled can not report error messages well on bulletin since it's an async method. // So, let's do it in onTrigger(). - public void onWebSocketServiceReady(final WebSocketService webSocketService) throws IOException { - + public void onWebSocketServiceReady(final WebSocketService webSocketService, final ProcessContext context) throws IOException { if (webSocketService instanceof WebSocketClientService) { // If it's a ws client, then connect to the remote here. // Otherwise, ws server is already started at WebSocketServerService - ((WebSocketClientService) webSocketService).connect(endpointId); + WebSocketClientService webSocketClientService = (WebSocketClientService) webSocketService; + if (context.hasIncomingConnection()) { + final ProcessSession session = processSessionFactory.createSession(); + final FlowFile flowFile = session.get(); + final Map attributes = flowFile.getAttributes(); + try { + webSocketClientService.connect(endpointId, attributes); + } finally { + session.remove(flowFile); + session.commitAsync(); + } + } else { + webSocketClientService.connect(endpointId); + } } } @@ -137,6 +151,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF } protected abstract WebSocketService getWebSocketService(final ProcessContext context); + protected abstract String getEndpointId(final ProcessContext context); protected boolean isProcessorRegisteredToService() { @@ -146,7 +161,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF } @OnStopped - public void onStopped(final ProcessContext context) throws IOException { + public void onStopped(final ProcessContext context) { deregister(); } @@ -165,27 +180,36 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF } @Override - public final void onTrigger(final ProcessContext context, final ProcessSessionFactory sessionFactory) throws ProcessException { + public final void onTrigger(final ProcessContext context, final ProcessSessionFactory sessionFactory) { if (processSessionFactory == null) { processSessionFactory = sessionFactory; } if (!isProcessorRegisteredToService()) { try { - registerProcessorToService(context, webSocketService -> onWebSocketServiceReady(webSocketService)); - } catch (IOException|WebSocketConfigurationException e) { + registerProcessorToService(context, webSocketService -> onWebSocketServiceReady(webSocketService, context)); + } catch (IOException | WebSocketConfigurationException e) { // Deregister processor if it failed so that it can retry next onTrigger. deregister(); context.yield(); throw new ProcessException("Failed to register processor to WebSocket service due to: " + e, e); } + + } else { + try { + onWebSocketServiceReady(webSocketService, context); + } catch (IOException e) { + deregister(); + context.yield(); + throw new ProcessException("Failed to renew session and connect to WebSocket service due to: " + e, e); + } } context.yield();//nothing really to do here since handling WebSocket messages is done at ControllerService. } - private void enqueueMessage(final WebSocketMessage incomingMessage){ + private void enqueueMessage(final WebSocketMessage incomingMessage) { final ProcessSession session = processSessionFactory.createSession(); try { FlowFile messageFlowFile = session.create(); @@ -206,9 +230,9 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF final byte[] payload = incomingMessage.getPayload(); if (payload != null) { - messageFlowFile = session.write(messageFlowFile, out -> { - out.write(payload, incomingMessage.getOffset(), incomingMessage.getLength()); - }); + messageFlowFile = session.write(messageFlowFile, out -> + out.write(payload, incomingMessage.getOffset(), incomingMessage.getLength()) + ); } session.getProvenanceReporter().receive(messageFlowFile, getTransitUri(sessionInfo)); @@ -216,7 +240,7 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF if (incomingMessage instanceof WebSocketConnectedMessage) { session.transfer(messageFlowFile, REL_CONNECTED); } else { - switch (messageType) { + switch (Objects.requireNonNull(messageType)) { case TEXT: session.transfer(messageFlowFile, REL_MESSAGE_TEXT); break; @@ -233,6 +257,6 @@ public abstract class AbstractWebSocketGatewayProcessor extends AbstractSessionF } } - abstract protected String getTransitUri(final WebSocketSessionInfo sessionInfo); + protected abstract String getTransitUri(final WebSocketSessionInfo sessionInfo); } diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/ConnectWebSocket.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/ConnectWebSocket.java index 3688b7b790..dfbd9c482c 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/ConnectWebSocket.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/java/org/apache/nifi/processors/websocket/ConnectWebSocket.java @@ -44,7 +44,7 @@ import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes. import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID; @Tags({"subscribe", "WebSocket", "consume", "listen"}) -@InputRequirement(InputRequirement.Requirement.INPUT_FORBIDDEN) +@InputRequirement(InputRequirement.Requirement.INPUT_ALLOWED) @TriggerSerially @CapabilityDescription("Acts as a WebSocket client endpoint to interact with a remote WebSocket server." + " FlowFiles are transferred to downstream relationships according to received message types" + @@ -80,7 +80,7 @@ public class ConnectWebSocket extends AbstractWebSocketGatewayProcessor { private static final List descriptors; private static final Set relationships; - static{ + static { final List innerDescriptorsList = new ArrayList<>(); innerDescriptorsList.add(PROP_WEBSOCKET_CLIENT_SERVICE); innerDescriptorsList.add(PROP_WEBSOCKET_CLIENT_ID); @@ -113,6 +113,6 @@ public class ConnectWebSocket extends AbstractWebSocketGatewayProcessor { @Override protected String getTransitUri(final WebSocketSessionInfo sessionInfo) { - return ((WebSocketClientService)webSocketService).getTargetUri(); + return ((WebSocketClientService) webSocketService).getTargetUri(); } } diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/resources/docs/org.apache.nifi.processors.websocket.ConnectWebSocket/additionalDetails.html b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/resources/docs/org.apache.nifi.processors.websocket.ConnectWebSocket/additionalDetails.html new file mode 100644 index 0000000000..8c4c2e8135 --- /dev/null +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/main/resources/docs/org.apache.nifi.processors.websocket.ConnectWebSocket/additionalDetails.html @@ -0,0 +1,47 @@ + + + + + + ConnectWebSocket + + + + +

Summary

+

+ This processor acts as a WebSocket client endpoint to interact with a remote WebSocket server. + It is capable of receiving messages from a websocket server and it transfers them to downstream relationships + according to the received message types. +

+

+ The processor may have an incoming relationship, in which case flowfile attributes are passed down to its WebSocket Client Service. + This can be used to fine-tune the connection configuration (url and headers for example). For example "dynamic.url = + currentValue" flowfile attribute can be referenced in the WebSocket Client Service with the ${dynamic.url} expression. +

+

+ You can define custom websocket headers in the incoming flowfile as additional attributes. The attribute key + shall start with "header." and continue with they header key. For example: "header.Authorization". The attribute + value will be the corresponding header value. +

    +
  1. header.Autorization | Basic base64UserNamePassWord
  2. +
  3. header.Content-Type | application, audio, example
  4. +
+

+ For multiple header values provide a comma separated list. +

+ + \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/test/java/org/apache/nifi/processors/websocket/TestConnectWebSocket.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/test/java/org/apache/nifi/processors/websocket/TestConnectWebSocket.java index df6bd6fcb1..304cafcf7d 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/test/java/org/apache/nifi/processors/websocket/TestConnectWebSocket.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-processors/src/test/java/org/apache/nifi/processors/websocket/TestConnectWebSocket.java @@ -20,6 +20,8 @@ import org.apache.nifi.processor.ProcessSessionFactory; import org.apache.nifi.processor.Relationship; import org.apache.nifi.provenance.ProvenanceEventRecord; import org.apache.nifi.provenance.ProvenanceEventType; +import org.apache.nifi.remote.io.socket.NetworkUtils; +import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.MockProcessSession; import org.apache.nifi.util.SharedSessionState; @@ -29,9 +31,13 @@ import org.apache.nifi.websocket.AbstractWebSocketSession; import org.apache.nifi.websocket.WebSocketClientService; import org.apache.nifi.websocket.WebSocketMessage; import org.apache.nifi.websocket.WebSocketSession; +import org.apache.nifi.websocket.jetty.JettyWebSocketClient; +import org.apache.nifi.websocket.jetty.JettyWebSocketServer; +import org.junit.Assert; import org.junit.Test; import java.net.InetSocketAddress; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -51,7 +57,9 @@ public class TestConnectWebSocket extends TestListenWebSocket { @Test public void testSuccess() throws Exception { final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class); - final ConnectWebSocket processor = (ConnectWebSocket)runner.getProcessor(); + runner.setIncomingConnection(false); + + final ConnectWebSocket processor = (ConnectWebSocket) runner.getProcessor(); final SharedSessionState sharedSessionState = new SharedSessionState(processor, new AtomicLong(0)); // Use this custom session factory implementation so that createdSessions can be read from test case, @@ -121,4 +129,63 @@ public class TestConnectWebSocket extends TestListenWebSocket { assertTrue(provenanceEvents.stream().allMatch(event -> ProvenanceEventType.RECEIVE.equals(event.getEventType()))); } + @Test + public void testDynamicUrlsParsedFromFlowFileAndAbleToConnect() throws InitializationException { + // Start websocket server + final int port = NetworkUtils.availablePort(); + TestRunner webSocketListener = getListenWebSocket(port); + webSocketListener.run(1, false); + + final TestRunner runner = TestRunners.newTestRunner(ConnectWebSocket.class); + + final String serviceId = "ws-service"; + final String endpointId = "client-1"; + + Map attributes = new HashMap<>(); + attributes.put("dynamicUrlPart", "test"); + MockFlowFile flowFile = new MockFlowFile(1L); + flowFile.putAttributes(attributes); + runner.enqueue(flowFile); + + attributes.put("dynamicUrlPart", "test2"); + MockFlowFile flowFileWithWrongUrl = new MockFlowFile(2L); + flowFileWithWrongUrl.putAttributes(attributes); + runner.enqueue(flowFileWithWrongUrl); + + JettyWebSocketClient service = new JettyWebSocketClient(); + + + runner.addControllerService(serviceId, service); + runner.setProperty(service, JettyWebSocketClient.WS_URI, String.format("ws://localhost:%s/${dynamicUrlPart}", port)); + runner.enableControllerService(service); + + runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_SERVICE, serviceId); + runner.setProperty(ConnectWebSocket.PROP_WEBSOCKET_CLIENT_ID, endpointId); + + runner.run(1, false); + + final List flowFilesForRelationship = runner.getFlowFilesForRelationship(ConnectWebSocket.REL_CONNECTED); + assertEquals(1, flowFilesForRelationship.size()); + + final AssertionError assertionError = Assert.assertThrows(AssertionError.class, () -> runner.run(1)); + assertTrue(assertionError.getCause().getLocalizedMessage().contains("Failed to renew session and connect to WebSocket service")); + + runner.stop(); + webSocketListener.stop(); + } + + private TestRunner getListenWebSocket(final int port) throws InitializationException { + final TestRunner runner = TestRunners.newTestRunner(ListenWebSocket.class); + + final String serviceId = "ws-server-service"; + JettyWebSocketServer service = new JettyWebSocketServer(); + runner.addControllerService(serviceId, service); + runner.setProperty(service, JettyWebSocketServer.LISTEN_PORT, String.valueOf(port)); + runner.enableControllerService(service); + + runner.setProperty(ListenWebSocket.PROP_WEBSOCKET_SERVER_SERVICE, serviceId); + runner.setProperty(ListenWebSocket.PROP_SERVER_URL_PATH, "/test"); + + return runner; + } } diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-api/src/main/java/org/apache/nifi/websocket/WebSocketClientService.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-api/src/main/java/org/apache/nifi/websocket/WebSocketClientService.java index 9d4ef163fe..c5089b74d8 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-api/src/main/java/org/apache/nifi/websocket/WebSocketClientService.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-api/src/main/java/org/apache/nifi/websocket/WebSocketClientService.java @@ -20,6 +20,7 @@ import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.controller.ConfigurationContext; import java.io.IOException; +import java.util.Map; /** * Control a WebSocket client instance. @@ -34,6 +35,10 @@ public interface WebSocketClientService extends WebSocketService { void connect(final String clientId) throws IOException; + default void connect(final String clientId, final Map flowFileAttributes) throws IOException { + connect(clientId); + } + String getTargetUri(); } diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/dto/SessionInfo.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/dto/SessionInfo.java new file mode 100644 index 0000000000..7caa2f5ff9 --- /dev/null +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/dto/SessionInfo.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dto; + +import java.util.Map; + +public class SessionInfo { + + private final String sessionId; + private final Map flowFileAttributes; + + public SessionInfo(final String sessionId, final Map flowFileAttributes) { + this.sessionId = sessionId; + this.flowFileAttributes = flowFileAttributes; + } + + public String getSessionId() { + return sessionId; + } + + public Map getFlowFileAttributes() { + return flowFileAttributes; + } +} diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java index 84e84f3044..60bdbf9999 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java @@ -16,6 +16,7 @@ */ package org.apache.nifi.websocket.jetty; +import dto.SessionInfo; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.lifecycle.OnDisabled; @@ -27,6 +28,7 @@ import org.apache.nifi.components.ValidationResult; import org.apache.nifi.controller.ConfigurationContext; import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.ssl.SSLContextService; import org.apache.nifi.util.StringUtils; @@ -40,14 +42,17 @@ import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; +import util.HeaderMapExtractor; import java.io.IOException; import java.net.URI; +import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Base64; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -68,7 +73,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen .displayName("WebSocket URI") .description("The WebSocket URI this client connects to.") .required(true) - .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) .addValidator(StandardValidators.URI_VALIDATOR) .addValidator((subject, input, context) -> { final ValidationResult.Builder result = new ValidationResult.Builder() @@ -79,7 +84,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen result.explanation("Expression Language Present").valid(true); } else { result.explanation("Protocol should be either 'ws' or 'wss'.") - .valid(input.startsWith("ws://") || input.startsWith("wss://")); + .valid(input.startsWith("ws://") || input.startsWith("wss://")); } return result.build(); @@ -161,8 +166,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen private static final List properties; static { - final List props = new ArrayList<>(); - props.addAll(getAbstractPropertyDescriptors()); + final List props = new ArrayList<>(getAbstractPropertyDescriptors()); props.add(WS_URI); props.add(SSL_CONTEXT); props.add(CONNECTION_TIMEOUT); @@ -176,12 +180,14 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen properties = Collections.unmodifiableList(props); } + private final Map activeSessions = new ConcurrentHashMap<>(); + private final ReentrantLock connectionLock = new ReentrantLock(); private WebSocketClient client; private URI webSocketUri; private String authorizationHeader; private long connectionTimeoutMillis; private volatile ScheduledExecutorService sessionMaintenanceScheduler; - private final ReentrantLock connectionLock = new ReentrantLock(); + private ConfigurationContext configurationContext; @Override protected List getSupportedPropertyDescriptors() { @@ -190,8 +196,8 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen @OnEnabled @Override - public void startClient(final ConfigurationContext context) throws Exception{ - + public void startClient(final ConfigurationContext context) throws Exception { + configurationContext = context; final SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class); SslContextFactory sslContextFactory = null; if (sslService != null) { @@ -227,8 +233,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen client.start(); activeSessions.clear(); - - webSocketUri = new URI(context.getProperty(WS_URI).evaluateAttributeExpressions().getValue()); + webSocketUri = new URI(context.getProperty(WS_URI).evaluateAttributeExpressions(new HashMap<>()).getValue()); connectionTimeoutMillis = context.getProperty(CONNECTION_TIMEOUT).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS); final Long sessionMaintenanceInterval = context.getProperty(SESSION_MAINTENANCE_INTERVAL).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS); @@ -281,10 +286,20 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen @Override public void connect(final String clientId) throws IOException { - connect(clientId, null); + connect(clientId, null, Collections.emptyMap()); } - private void connect(final String clientId, String sessionId) throws IOException { + @Override + public void connect(final String clientId, final Map flowFileAttributes) throws IOException { + connect(clientId, null, flowFileAttributes); + } + + private void connect(final String clientId, final String sessionId, final Map flowFileAttributes) throws IOException { + try { + webSocketUri = new URI(configurationContext.getProperty(WS_URI).evaluateAttributeExpressions(flowFileAttributes).getValue()); + } catch (URISyntaxException e) { + throw new ProcessException("Could not create websocket URI", e); + } connectionLock.lock(); @@ -293,17 +308,21 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen try { router = routers.getRouterOrFail(clientId); } catch (WebSocketConfigurationException e) { - throw new IllegalStateException("Failed to get router due to: " + e, e); + throw new IllegalStateException("Failed to get router due to: " + e, e); } final RoutingWebSocketListener listener = new RoutingWebSocketListener(router); listener.setSessionId(sessionId); final ClientUpgradeRequest request = new ClientUpgradeRequest(); + + if (!flowFileAttributes.isEmpty()) { + request.setHeaders(HeaderMapExtractor.getHeaderMap(flowFileAttributes)); + } if (!StringUtils.isEmpty(authorizationHeader)) { request.setHeader(HttpHeader.AUTHORIZATION.asString(), authorizationHeader); } final Future connect = client.connect(listener, webSocketUri, request); - getLogger().info("Connecting to : {}", new Object[]{webSocketUri}); + getLogger().info("Connecting to : {}", webSocketUri); final Session session; try { @@ -311,8 +330,8 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen } catch (Exception e) { throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e); } - getLogger().info("Connected, session={}", new Object[]{session}); - activeSessions.put(clientId, listener.getSessionId()); + getLogger().info("Connected, session={}", session); + activeSessions.put(clientId, new SessionInfo(listener.getSessionId(), flowFileAttributes)); } finally { connectionLock.unlock(); @@ -320,8 +339,6 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen } - private Map activeSessions = new ConcurrentHashMap<>(); - void maintainSessions() throws Exception { if (client == null) { return; @@ -338,19 +355,19 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen router = routers.getRouterOrFail(clientId); } catch (final WebSocketConfigurationException e) { if (logger.isDebugEnabled()) { - logger.debug("The clientId {} is no longer active. Discarding the clientId.", new Object[]{clientId}); + logger.debug("The clientId {} is no longer active. Discarding the clientId.", clientId); } activeSessions.remove(clientId); continue; } - final String sessionId = activeSessions.get(clientId); + final SessionInfo sessionInfo = activeSessions.get(clientId); // If this session is still alive, do nothing. - if (!router.containsSession(sessionId)) { + if (!router.containsSession(sessionInfo.getSessionId())) { // This session is no longer active, reconnect it. // If it fails, the sessionId will remain in activeSessions, and retries later. // This reconnect attempt is continued until user explicitly stops a processor or this controller service. - connect(clientId, sessionId); + connect(clientId, sessionInfo.getSessionId(), sessionInfo.getFlowFileAttributes()); } } } finally { @@ -358,7 +375,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen } if (logger.isDebugEnabled()) { - logger.debug("Session maintenance completed. activeSessions={}", new Object[]{activeSessions}); + logger.debug("Session maintenance completed. activeSessions={}", activeSessions); } } diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/util/HeaderMapExtractor.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/util/HeaderMapExtractor.java new file mode 100644 index 0000000000..f29db9f416 --- /dev/null +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/util/HeaderMapExtractor.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package util; + +import org.apache.nifi.util.StringUtils; + +import java.util.AbstractMap; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public final class HeaderMapExtractor { + + private HeaderMapExtractor(){ + // Utility class, not meant to be instantiated. + } + + public static final String HEADER_PREFIX = "header."; + + public static Map> getHeaderMap(final Map flowFileAttributes) { + return flowFileAttributes.entrySet().stream() + .filter(entry -> entry.getKey().startsWith(HEADER_PREFIX)) + .filter(entry -> StringUtils.isNotBlank(entry.getValue())) + .map(entry -> new AbstractMap.SimpleImmutableEntry<>(StringUtils.substringAfter(entry.getKey(), HEADER_PREFIX), entry.getValue())) + .collect(Collectors.toMap(Map.Entry::getKey, HeaderMapExtractor::headerValueMapper)); + } + + private static List headerValueMapper(Map.Entry entry) { + return Arrays.stream(entry.getValue().split(",")).map(String::trim).collect(Collectors.toList()); + } + +} diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebSocketCommunication.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebSocketCommunication.java index a67f489fda..51f12f3ab7 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebSocketCommunication.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebSocketCommunication.java @@ -30,6 +30,7 @@ import org.mockito.invocation.InvocationOnMock; import java.net.ServerSocket; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -234,7 +235,7 @@ public class ITJettyWebSocketCommunication { clientService.registerProcessor(clientId, clientProcessor); - clientService.connect(clientId); + clientService.connect(clientId, Collections.emptyMap()); assertTrue("WebSocket client should be able to fire connected event.", clientConnectedServer.await(5, TimeUnit.SECONDS)); assertTrue("WebSocket server should be able to fire connected event.", serverIsConnectedByClient.await(5, TimeUnit.SECONDS)); diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/util/HeaderMapExtractorTest.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/util/HeaderMapExtractorTest.java new file mode 100644 index 0000000000..dec8a36d65 --- /dev/null +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/util/HeaderMapExtractorTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.websocket.util; + +import org.junit.Test; +import util.HeaderMapExtractor; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class HeaderMapExtractorTest { + + @Test + public void testMapExtractor() { + + // GIVEN + final Map attributes = new HashMap<>(); + attributes.put("header.AUTHORIZATION", "AUTH_VALUE"); + attributes.put("header.MULTI_HEADER_KEY", "FIRST, SECOND ,THIRD"); + attributes.put("header.dots.dots.dots.DOTS_KEY", "DOTS_VALUE"); + attributes.put("something.else.SOMETHING_ELSE_KEY", "SOMETHING_ELSE_VALUE"); + attributes.put("header.EMPTY_VALUE_KEY", ""); + attributes.put("headerButNotReally.UNRECOGNIZED", "NOT_A_HEADER_VALUE"); + + final Map> expected = new HashMap<>(); + expected.put("AUTHORIZATION", Collections.singletonList("AUTH_VALUE")); + expected.put("MULTI_HEADER_KEY", Arrays.asList("FIRST", "SECOND", "THIRD")); + expected.put("dots.dots.dots.DOTS_KEY", Collections.singletonList("DOTS_VALUE")); + + // WHEN + final Map> actual = HeaderMapExtractor.getHeaderMap(attributes); + + // THEN + assertEquals(expected, actual); + + assertEquals(expected.size(), actual.size()); + for (Map.Entry> entry : actual.entrySet()) { + assertTrue(expected.containsKey(entry.getKey())); + final List actualHeaderValues = entry.getValue(); + final List expectedHeaderValues = expected.get(entry.getKey()); + for (int i = 0; i < actualHeaderValues.size(); i++) { + assertEquals(expectedHeaderValues.get(i), actualHeaderValues.get(i)); + } + } + } + +}