Support for websocket multiplexing to all existing websocket

connections (Think chat to all clients instead of individual person).
The core change was a change in WebSocketMessageRouter.java where if a
sessionId is not present the message is sent to all connected clients.
So the key is leaving the sessionId to empty or null to send to all
clients. If the sessionId is specified the message will be sent just to
that session specified.

This closes #1649.

Signed-off-by: Koji Kawamura <ijokarumawak@apache.org>
This commit is contained in:
Jeremy Dyer 2017-04-04 14:46:17 -04:00 committed by Koji Kawamura
parent 816034bd01
commit 769e874677
3 changed files with 81 additions and 58 deletions

View File

@ -17,6 +17,26 @@
package org.apache.nifi.processors.websocket; package org.apache.nifi.processors.websocket;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_FAILURE_DETAIL;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_LOCAL_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_MESSAGE_TYPE;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_REMOTE_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID;
import static org.apache.nifi.websocket.WebSocketMessage.CHARSET_NAME;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.annotation.behavior.InputRequirement; import org.apache.nifi.annotation.behavior.InputRequirement;
import org.apache.nifi.annotation.behavior.TriggerSerially; import org.apache.nifi.annotation.behavior.TriggerSerially;
@ -38,25 +58,6 @@ import org.apache.nifi.websocket.WebSocketConfigurationException;
import org.apache.nifi.websocket.WebSocketMessage; import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketService; import org.apache.nifi.websocket.WebSocketService;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_FAILURE_DETAIL;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_LOCAL_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_MESSAGE_TYPE;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_REMOTE_ADDRESS;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_SESSION_ID;
import static org.apache.nifi.websocket.WebSocketMessage.CHARSET_NAME;
@Tags({"WebSocket", "publish", "send"}) @Tags({"WebSocket", "publish", "send"})
@InputRequirement(InputRequirement.Requirement.INPUT_REQUIRED) @InputRequirement(InputRequirement.Requirement.INPUT_REQUIRED)
@TriggerSerially @TriggerSerially
@ -76,7 +77,8 @@ public class PutWebSocket extends AbstractProcessor {
public static final PropertyDescriptor PROP_WS_SESSION_ID = new PropertyDescriptor.Builder() public static final PropertyDescriptor PROP_WS_SESSION_ID = new PropertyDescriptor.Builder()
.name("websocket-session-id") .name("websocket-session-id")
.displayName("WebSocket Session Id") .displayName("WebSocket Session Id")
.description("A NiFi Expression to retrieve the session id.") .description("A NiFi Expression to retrieve the session id. If not specified, a message will be " +
"sent to all connected WebSocket peers for the WebSocket controller service endpoint.")
.required(true) .required(true)
.addValidator(StandardValidators.NON_BLANK_VALIDATOR) .addValidator(StandardValidators.NON_BLANK_VALIDATOR)
.expressionLanguageSupported(true) .expressionLanguageSupported(true)
@ -166,8 +168,11 @@ public class PutWebSocket extends AbstractProcessor {
.evaluateAttributeExpressions(flowfile).getValue(); .evaluateAttributeExpressions(flowfile).getValue();
final WebSocketMessage.Type messageType = WebSocketMessage.Type.valueOf(messageTypeStr); final WebSocketMessage.Type messageType = WebSocketMessage.Type.valueOf(messageTypeStr);
if (StringUtils.isEmpty(sessionId) if (StringUtils.isEmpty(sessionId)) {
|| StringUtils.isEmpty(webSocketServiceId) getLogger().debug("Specific SessionID not specified. Message will be broadcast to all connected clients.");
}
if (StringUtils.isEmpty(webSocketServiceId)
|| StringUtils.isEmpty(webSocketServiceEndpoint)) { || StringUtils.isEmpty(webSocketServiceEndpoint)) {
transferToFailure(processSession, flowfile, "Required WebSocket attribute was not found."); transferToFailure(processSession, flowfile, "Required WebSocket attribute was not found.");
return; return;
@ -187,9 +192,14 @@ public class PutWebSocket extends AbstractProcessor {
final byte[] messageContent = new byte[(int) flowfile.getSize()]; final byte[] messageContent = new byte[(int) flowfile.getSize()];
final long startSending = System.currentTimeMillis(); final long startSending = System.currentTimeMillis();
final AtomicReference<String> transitUri = new AtomicReference<>();
final Map<String, String> attrs = new HashMap<>(); final Map<String, String> attrs = new HashMap<>();
attrs.put(ATTR_WS_CS_ID, webSocketService.getIdentifier()); attrs.put(ATTR_WS_CS_ID, webSocketService.getIdentifier());
if (!StringUtils.isEmpty(sessionId)) {
attrs.put(ATTR_WS_SESSION_ID, sessionId); attrs.put(ATTR_WS_SESSION_ID, sessionId);
}
attrs.put(ATTR_WS_ENDPOINT_ID, webSocketServiceEndpoint); attrs.put(ATTR_WS_ENDPOINT_ID, webSocketServiceEndpoint);
attrs.put(ATTR_WS_MESSAGE_TYPE, messageTypeStr); attrs.put(ATTR_WS_MESSAGE_TYPE, messageTypeStr);
@ -211,13 +221,14 @@ public class PutWebSocket extends AbstractProcessor {
attrs.put(ATTR_WS_LOCAL_ADDRESS, sender.getLocalAddress().toString()); attrs.put(ATTR_WS_LOCAL_ADDRESS, sender.getLocalAddress().toString());
attrs.put(ATTR_WS_REMOTE_ADDRESS, sender.getRemoteAddress().toString()); attrs.put(ATTR_WS_REMOTE_ADDRESS, sender.getRemoteAddress().toString());
transitUri.set(sender.getTransitUri());
});
final FlowFile updatedFlowFile = processSession.putAllAttributes(flowfile, attrs); final FlowFile updatedFlowFile = processSession.putAllAttributes(flowfile, attrs);
final long transmissionMillis = System.currentTimeMillis() - startSending; final long transmissionMillis = System.currentTimeMillis() - startSending;
processSession.getProvenanceReporter().send(updatedFlowFile, sender.getTransitUri(), transmissionMillis); processSession.getProvenanceReporter().send(updatedFlowFile, transitUri.get(), transmissionMillis);
processSession.transfer(updatedFlowFile, REL_SUCCESS); processSession.transfer(updatedFlowFile, REL_SUCCESS);
});
} catch (WebSocketConfigurationException|IllegalStateException|IOException e) { } catch (WebSocketConfigurationException|IllegalStateException|IOException e) {
// WebSocketConfigurationException: If the corresponding WebSocketGatewayProcessor has been stopped. // WebSocketConfigurationException: If the corresponding WebSocketGatewayProcessor has been stopped.
@ -235,5 +246,4 @@ public class PutWebSocket extends AbstractProcessor {
return flowfile; return flowfile;
} }
} }

View File

@ -16,24 +16,6 @@
*/ */
package org.apache.nifi.processors.websocket; package org.apache.nifi.processors.websocket;
import org.apache.nifi.controller.ControllerService;
import org.apache.nifi.provenance.ProvenanceEventRecord;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.apache.nifi.websocket.AbstractWebSocketSession;
import org.apache.nifi.websocket.SendMessage;
import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketService;
import org.apache.nifi.websocket.WebSocketSession;
import org.junit.Test;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_CS_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_ENDPOINT_ID;
import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_FAILURE_DETAIL; import static org.apache.nifi.processors.websocket.WebSocketProcessorAttributes.ATTR_WS_FAILURE_DETAIL;
@ -50,6 +32,24 @@ import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.nifi.controller.ControllerService;
import org.apache.nifi.provenance.ProvenanceEventRecord;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.apache.nifi.websocket.AbstractWebSocketSession;
import org.apache.nifi.websocket.SendMessage;
import org.apache.nifi.websocket.WebSocketMessage;
import org.apache.nifi.websocket.WebSocketService;
import org.apache.nifi.websocket.WebSocketSession;
import org.junit.Test;
public class TestPutWebSocket { public class TestPutWebSocket {
@ -92,12 +92,12 @@ public class TestPutWebSocket {
runner.run(); runner.run();
final List<MockFlowFile> succeededFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_SUCCESS); final List<MockFlowFile> succeededFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_SUCCESS);
assertEquals(0, succeededFlowFiles.size()); //assertEquals(0, succeededFlowFiles.size()); //No longer valid test after NIFI-3318 since not specifying sessionid will send to all clients
assertEquals(1, succeededFlowFiles.size());
final List<MockFlowFile> failedFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_FAILURE); final List<MockFlowFile> failedFlowFiles = runner.getFlowFilesForRelationship(PutWebSocket.REL_FAILURE);
assertEquals(1, failedFlowFiles.size()); //assertEquals(1, failedFlowFiles.size()); //No longer valid test after NIFI-3318
final MockFlowFile failedFlowFile = failedFlowFiles.iterator().next(); assertEquals(0, failedFlowFiles.size());
assertNotNull(failedFlowFile.getAttribute(ATTR_WS_FAILURE_DETAIL));
final List<ProvenanceEventRecord> provenanceEvents = runner.getProvenanceEvents(); final List<ProvenanceEventRecord> provenanceEvents = runner.getProvenanceEvents();
assertEquals(0, provenanceEvents.size()); assertEquals(0, provenanceEvents.size());

View File

@ -16,14 +16,15 @@
*/ */
package org.apache.nifi.websocket; package org.apache.nifi.websocket;
import org.apache.nifi.processor.Processor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.processor.Processor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class WebSocketMessageRouter { public class WebSocketMessageRouter {
private static final Logger logger = LoggerFactory.getLogger(WebSocketMessageRouter.class); private static final Logger logger = LoggerFactory.getLogger(WebSocketMessageRouter.class);
private final String endpointId; private final String endpointId;
@ -101,8 +102,20 @@ public class WebSocketMessageRouter {
} }
public void sendMessage(final String sessionId, final SendMessage sendMessage) throws IOException { public void sendMessage(final String sessionId, final SendMessage sendMessage) throws IOException {
if (!StringUtils.isEmpty(sessionId)) {
final WebSocketSession session = getSessionOrFail(sessionId); final WebSocketSession session = getSessionOrFail(sessionId);
sendMessage.send(session); sendMessage.send(session);
} else {
//The sessionID is not specified so broadcast the message to all connected client sessions.
sessions.keySet().forEach(itrSessionId -> {
try {
final WebSocketSession session = getSessionOrFail(itrSessionId);
sendMessage.send(session);
} catch (IOException e) {
logger.warn("Failed to send message to session {} due to {}", itrSessionId, e, e);
}
});
}
} }
public void disconnect(final String sessionId, final String reason) throws IOException { public void disconnect(final String sessionId, final String reason) throws IOException {