NIFI-3609: ConnectWebSocket auto session recovery

- Removed unused disconnect method from WebSocketService interface.
- Added session maintenance background thread at JettyWebSocketClient
  which reconnects sessions those are still referred by ConnectWebSocket
  processor but no longer active.
- Added Session Maintenance Interval property to JettyWebSocketClient.
- Allowed specifying existing session id so that it can be recovered
  transparently.
- Moved test classes to appropriate package.
- Added test cases that verify the same session id can be used after
  WebSocket server restarts.
This commit is contained in:
Koji Kawamura 2017-03-16 17:10:38 +09:00 committed by Jeremy Dyer
parent 8fa35294eb
commit 0a014b471b
11 changed files with 218 additions and 36 deletions

View File

@ -45,9 +45,4 @@ public abstract class AbstractWebSocketService extends AbstractControllerService
routers.sendMessage(endpointId, sessionId, sendMessage);
}
@Override
public void disconnect(final String endpointId, final String sessionId, final String reason) throws IOException, WebSocketConfigurationException {
routers.disconnect(endpointId, sessionId, reason);
}
}

View File

@ -124,4 +124,8 @@ public class WebSocketMessageRouter {
sessions.remove(sessionId);
}
public boolean containsSession(final String sessionId) {
return sessions.containsKey(sessionId);
}
}

View File

@ -59,6 +59,7 @@ public class WebSocketMessageRouters {
public synchronized void deregisterProcessor(final String endpointId, final Processor processor) throws WebSocketConfigurationException {
final WebSocketMessageRouter router = getRouterOrFail(endpointId);
routers.remove(endpointId);
router.deregisterProcessor(processor);
}
@ -67,9 +68,4 @@ public class WebSocketMessageRouters {
router.sendMessage(sessionId, sendMessage);
}
public void disconnect(final String endpointId, final String sessionId, final String reason) throws IOException, WebSocketConfigurationException {
final WebSocketMessageRouter router = getRouterOrFail(endpointId);
router.disconnect(sessionId, reason);
}
}

View File

@ -45,6 +45,4 @@ public interface WebSocketService extends ControllerService {
void sendMessage(final String endpointId, final String sessionId, final SendMessage sendMessage) throws IOException, WebSocketConfigurationException;
void disconnect(final String endpointId, final String sessionId, final String reason) throws Exception;
}

View File

@ -24,6 +24,7 @@ import org.apache.nifi.annotation.lifecycle.OnShutdown;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.ssl.SSLContextService;
import org.apache.nifi.websocket.WebSocketClientService;
@ -39,8 +40,13 @@ import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
@Tags({"WebSocket", "Jetty", "client"})
@CapabilityDescription("Implementation of WebSocketClientService." +
@ -81,6 +87,16 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
.defaultValue("3 sec")
.build();
public static final PropertyDescriptor SESSION_MAINTENANCE_INTERVAL = new PropertyDescriptor.Builder()
.name("session-maintenance-interval")
.displayName("Session Maintenance Interval")
.description("The interval between session maintenance activities.")
.required(true)
.expressionLanguageSupported(true)
.addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
.defaultValue("10 sec")
.build();
private static final List<PropertyDescriptor> properties;
static {
@ -89,6 +105,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
props.add(WS_URI);
props.add(SSL_CONTEXT);
props.add(CONNECTION_TIMEOUT);
props.add(SESSION_MAINTENANCE_INTERVAL);
properties = Collections.unmodifiableList(props);
}
@ -96,6 +113,8 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
private WebSocketClient client;
private URI webSocketUri;
private long connectionTimeoutMillis;
private volatile ScheduledExecutorService sessionMaintenanceScheduler;
private final ReentrantLock connectionLock = new ReentrantLock();
@Override
protected List<PropertyDescriptor> getSupportedPropertyDescriptors() {
@ -116,15 +135,38 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
configurePolicy(context, client.getPolicy());
client.start();
activeSessions.clear();
webSocketUri = new URI(context.getProperty(WS_URI).getValue());
connectionTimeoutMillis = context.getProperty(CONNECTION_TIMEOUT).asTimePeriod(TimeUnit.MILLISECONDS);
final Long sessionMaintenanceInterval = context.getProperty(SESSION_MAINTENANCE_INTERVAL).asTimePeriod(TimeUnit.MILLISECONDS);
sessionMaintenanceScheduler = Executors.newSingleThreadScheduledExecutor();
sessionMaintenanceScheduler.scheduleAtFixedRate(() -> {
try {
maintainSessions();
} catch (final Exception e) {
getLogger().warn("Failed to maintain sessions due to {}", new Object[]{e}, e);
}
}, sessionMaintenanceInterval, sessionMaintenanceInterval, TimeUnit.MILLISECONDS);
}
@OnDisabled
@OnShutdown
@Override
public void stopClient() throws Exception {
activeSessions.clear();
if (sessionMaintenanceScheduler != null) {
try {
sessionMaintenanceScheduler.shutdown();
} catch (Exception e) {
getLogger().warn("Failed to shutdown session maintainer due to {}", new Object[]{e}, e);
}
sessionMaintenanceScheduler = null;
}
if (client == null) {
return;
}
@ -135,27 +177,81 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen
@Override
public void connect(final String clientId) throws IOException {
connect(clientId, null);
}
private void connect(final String clientId, String sessionId) throws IOException {
connectionLock.lock();
final WebSocketMessageRouter router;
try {
router = routers.getRouterOrFail(clientId);
} catch (WebSocketConfigurationException e) {
throw new IllegalStateException("Failed to get router due to: " + e, e);
final WebSocketMessageRouter router;
try {
router = routers.getRouterOrFail(clientId);
} catch (WebSocketConfigurationException e) {
throw new IllegalStateException("Failed to get router due to: " + e, e);
}
final RoutingWebSocketListener listener = new RoutingWebSocketListener(router);
listener.setSessionId(sessionId);
final ClientUpgradeRequest request = new ClientUpgradeRequest();
final Future<Session> connect = client.connect(listener, webSocketUri, request);
getLogger().info("Connecting to : {}", new Object[]{webSocketUri});
final Session session;
try {
session = connect.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS);
} catch (Exception e) {
throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e);
}
getLogger().info("Connected, session={}", new Object[]{session});
activeSessions.put(clientId, listener.getSessionId());
} finally {
connectionLock.unlock();
}
final RoutingWebSocketListener listener = new RoutingWebSocketListener(router);
final ClientUpgradeRequest request = new ClientUpgradeRequest();
final Future<Session> connect = client.connect(listener, webSocketUri, request);
getLogger().info("Connecting to : {}", new Object[]{webSocketUri});
}
final Session session;
private Map<String, String> activeSessions = new ConcurrentHashMap<>();
void maintainSessions() throws Exception {
if (client == null) {
return;
}
connectionLock.lock();
final ComponentLog logger = getLogger();
try {
session = connect.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS);
} catch (Exception e) {
throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e);
}
getLogger().info("Connected, session={}", new Object[]{session});
// Loop through existing sessions and reconnect.
for (String clientId : activeSessions.keySet()) {
final WebSocketMessageRouter router;
try {
router = routers.getRouterOrFail(clientId);
} catch (final WebSocketConfigurationException e) {
if (logger.isDebugEnabled()) {
logger.debug("The clientId {} is no longer active. Discarding the clientId.", new Object[]{clientId});
}
activeSessions.remove(clientId);
continue;
}
final String sessionId = activeSessions.get(clientId);
// If this session is stil alive, do nothing.
if (!router.containsSession(sessionId)) {
// This session is no longer active, reconnect it.
// If it fails, the sessionId will remain in activeSessions, and retries later.
connect(clientId, sessionId);
}
}
} finally {
connectionLock.unlock();
}
if (logger.isDebugEnabled()) {
logger.debug("Session maintenance completed. activeSessions={}", new Object[]{activeSessions});
}
}
@Override

View File

@ -33,7 +33,11 @@ public class RoutingWebSocketListener extends WebSocketAdapter {
@Override
public void onWebSocketConnect(final Session session) {
super.onWebSocketConnect(session);
sessionId = UUID.randomUUID().toString();
if (sessionId == null || sessionId.isEmpty()) {
// If sessionId is already assigned to this instance, don't publish new one.
// So that existing sesionId can be reused when reconnecting.
sessionId = UUID.randomUUID().toString();
}
final JettyWebSocketSession webSocketSession = new JettyWebSocketSession(sessionId, session);
router.captureSession(webSocketSession);
}
@ -53,4 +57,12 @@ public class RoutingWebSocketListener extends WebSocketAdapter {
public void onWebSocketBinary(final byte[] payload, final int offset, final int len) {
router.onWebSocketBinary(sessionId, payload, offset, len);
}
public void setSessionId(String sessionId) {
this.sessionId = sessionId;
}
public String getSessionId() {
return sessionId;
}
}

View File

@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket;
package org.apache.nifi.websocket.jetty;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;

View File

@ -14,10 +14,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket;
package org.apache.nifi.websocket.jetty;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.websocket.jetty.JettyWebSocketClient;
import org.junit.Test;
import java.util.Collection;

View File

@ -14,11 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket;
package org.apache.nifi.websocket.jetty;
import org.apache.nifi.processor.Processor;
import org.apache.nifi.websocket.jetty.JettyWebSocketClient;
import org.apache.nifi.websocket.jetty.JettyWebSocketServer;
import org.apache.nifi.websocket.BinaryMessageConsumer;
import org.apache.nifi.websocket.ConnectedListener;
import org.apache.nifi.websocket.TextMessageConsumer;
import org.apache.nifi.websocket.WebSocketClientService;
import org.apache.nifi.websocket.WebSocketServerService;
import org.apache.nifi.websocket.WebSocketSessionInfo;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -176,6 +180,84 @@ public class TestJettyWebSocketCommunication {
serverService.deregisterProcessor(serverPath, serverProcessor);
}
@Test
public void testClientServerCommunicationRecovery() throws Exception {
assumeFalse(isWindowsEnvironment());
// Expectations.
final CountDownLatch serverIsConnectedByClient = new CountDownLatch(1);
final CountDownLatch clientConnectedServer = new CountDownLatch(1);
final CountDownLatch serverReceivedTextMessageFromClient = new CountDownLatch(1);
final CountDownLatch serverReceivedBinaryMessageFromClient = new CountDownLatch(1);
final CountDownLatch clientReceivedTextMessageFromServer = new CountDownLatch(1);
final CountDownLatch clientReceivedBinaryMessageFromServer = new CountDownLatch(1);
final String textMessageFromClient = "Message from client.";
final String textMessageFromServer = "Message from server.";
final MockWebSocketProcessor serverProcessor = mock(MockWebSocketProcessor.class);
doReturn("serverProcessor1").when(serverProcessor).getIdentifier();
final AtomicReference<String> serverSessionIdRef = new AtomicReference<>();
doAnswer(invocation -> assertConnectedEvent(serverIsConnectedByClient, serverSessionIdRef, invocation))
.when(serverProcessor).connected(any(WebSocketSessionInfo.class));
doAnswer(invocation -> assertConsumeTextMessage(serverReceivedTextMessageFromClient, textMessageFromClient, invocation))
.when(serverProcessor).consume(any(WebSocketSessionInfo.class), anyString());
doAnswer(invocation -> assertConsumeBinaryMessage(serverReceivedBinaryMessageFromClient, textMessageFromClient, invocation))
.when(serverProcessor).consume(any(WebSocketSessionInfo.class), any(byte[].class), anyInt(), anyInt());
serverService.registerProcessor(serverPath, serverProcessor);
final String clientId = "client1";
final MockWebSocketProcessor clientProcessor = mock(MockWebSocketProcessor.class);
doReturn("clientProcessor1").when(clientProcessor).getIdentifier();
final AtomicReference<String> clientSessionIdRef = new AtomicReference<>();
doAnswer(invocation -> assertConnectedEvent(clientConnectedServer, clientSessionIdRef, invocation))
.when(clientProcessor).connected(any(WebSocketSessionInfo.class));
doAnswer(invocation -> assertConsumeTextMessage(clientReceivedTextMessageFromServer, textMessageFromServer, invocation))
.when(clientProcessor).consume(any(WebSocketSessionInfo.class), anyString());
doAnswer(invocation -> assertConsumeBinaryMessage(clientReceivedBinaryMessageFromServer, textMessageFromServer, invocation))
.when(clientProcessor).consume(any(WebSocketSessionInfo.class), any(byte[].class), anyInt(), anyInt());
clientService.registerProcessor(clientId, clientProcessor);
clientService.connect(clientId);
assertTrue("WebSocket client should be able to fire connected event.", clientConnectedServer.await(5, TimeUnit.SECONDS));
assertTrue("WebSocket server should be able to fire connected event.", serverIsConnectedByClient.await(5, TimeUnit.SECONDS));
// Nothing happens if maintenance is executed while sessions are alive.
((JettyWebSocketClient) clientService).maintainSessions();
// Restart server.
serverService.stopServer();
serverService.startServer(serverServiceContext.getConfigurationContext());
// Sessions will be recreated with the same session ids.
((JettyWebSocketClient) clientService).maintainSessions();
clientService.sendMessage(clientId, clientSessionIdRef.get(), sender -> sender.sendString(textMessageFromClient));
clientService.sendMessage(clientId, clientSessionIdRef.get(), sender -> sender.sendBinary(ByteBuffer.wrap(textMessageFromClient.getBytes())));
assertTrue("WebSocket server should be able to consume text message.", serverReceivedTextMessageFromClient.await(5, TimeUnit.SECONDS));
assertTrue("WebSocket server should be able to consume binary message.", serverReceivedBinaryMessageFromClient.await(5, TimeUnit.SECONDS));
serverService.sendMessage(serverPath, serverSessionIdRef.get(), sender -> sender.sendString(textMessageFromServer));
serverService.sendMessage(serverPath, serverSessionIdRef.get(), sender -> sender.sendBinary(ByteBuffer.wrap(textMessageFromServer.getBytes())));
assertTrue("WebSocket client should be able to consume text message.", clientReceivedTextMessageFromServer.await(5, TimeUnit.SECONDS));
assertTrue("WebSocket client should be able to consume binary message.", clientReceivedBinaryMessageFromServer.await(5, TimeUnit.SECONDS));
clientService.deregisterProcessor(clientId, clientProcessor);
serverService.deregisterProcessor(serverPath, serverProcessor);
}
protected Object assertConnectedEvent(CountDownLatch latch, AtomicReference<String> sessionIdRef, InvocationOnMock invocation) {
final WebSocketSessionInfo sessionInfo = invocation.getArgumentAt(0, WebSocketSessionInfo.class);
assertNotNull(sessionInfo.getLocalAddress());

View File

@ -14,10 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket;
package org.apache.nifi.websocket.jetty;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.ssl.StandardSSLContextService;
import org.apache.nifi.websocket.WebSocketService;
import org.junit.Test;

View File

@ -14,10 +14,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.websocket;
package org.apache.nifi.websocket.jetty;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.websocket.jetty.JettyWebSocketServer;
import org.junit.Test;
import java.util.Collection;