diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java index 823fadf580..83fa91090b 100644 --- a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/main/java/org/apache/nifi/websocket/jetty/JettyWebSocketClient.java @@ -101,6 +101,16 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen .defaultValue("3 sec") .build(); + public static final PropertyDescriptor CONNECTION_ATTEMPT_COUNT = new PropertyDescriptor.Builder() + .name("connection-attempt-timeout") + .displayName("Connection Attempt Count") + .description("The number of times to try and establish a connection.") + .required(true) + .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .addValidator(StandardValidators.POSITIVE_INTEGER_VALIDATOR) + .defaultValue("3") + .build(); + public static final PropertyDescriptor SESSION_MAINTENANCE_INTERVAL = new PropertyDescriptor.Builder() .name("session-maintenance-interval") .displayName("Session Maintenance Interval") @@ -183,6 +193,7 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen props.add(WS_URI); props.add(SSL_CONTEXT); props.add(CONNECTION_TIMEOUT); + props.add(CONNECTION_ATTEMPT_COUNT); props.add(SESSION_MAINTENANCE_INTERVAL); props.add(USER_NAME); props.add(USER_PASSWORD); @@ -347,14 +358,23 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen if (!StringUtils.isEmpty(authorizationHeader)) { request.setHeader(HttpHeader.AUTHORIZATION.asString(), authorizationHeader); } - final Future connect = client.connect(listener, webSocketUri, request); - getLogger().info("Connecting to : {}", webSocketUri); - final Session session; - try { - session = connect.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS); - } catch (Exception e) { - throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e); + final int connectCount = configurationContext.getProperty(CONNECTION_ATTEMPT_COUNT).evaluateAttributeExpressions().asInteger(); + + Session session = null; + for (int i = 0; i < connectCount; i++) { + final Future connect = createWebsocketSession(listener, request); + getLogger().info("Connecting to : {}", webSocketUri); + try { + session = connect.get(connectionTimeoutMillis, TimeUnit.MILLISECONDS); + break; + } catch (Exception e) { + if (i == connectCount - 1) { + throw new IOException("Failed to connect " + webSocketUri + " due to: " + e, e); + } else { + getLogger().warn("Failed to connect to {}, reconnection attempt {}", webSocketUri, i + 1); + } + } } getLogger().info("Connected, session={}", session); activeSessions.put(clientId, new SessionInfo(listener.getSessionId(), flowFileAttributes)); @@ -365,6 +385,10 @@ public class JettyWebSocketClient extends AbstractJettyWebSocketService implemen } + Future createWebsocketSession(RoutingWebSocketListener listener, ClientUpgradeRequest request) throws IOException { + return client.connect(listener, webSocketUri, request); + } + void maintainSessions() throws Exception { if (client == null) { return; diff --git a/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebsocketReconnect.java b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebsocketReconnect.java new file mode 100644 index 0000000000..08ff79da5b --- /dev/null +++ b/nifi-nar-bundles/nifi-websocket-bundle/nifi-websocket-services-jetty/src/test/java/org/apache/nifi/websocket/jetty/ITJettyWebsocketReconnect.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.websocket.jetty; + +import org.apache.nifi.remote.io.socket.NetworkUtils; +import org.apache.nifi.websocket.WebSocketClientService; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class ITJettyWebsocketReconnect { + + private ControllerServiceTestContext clientServiceContext; + private WebSocketClientService clientService; + + @BeforeEach + public void setup() throws Exception { + setupClient(); + } + + @AfterEach + public void teardown() throws Exception { + clientService.stopClient(); + } + + private void setupClient() throws Exception { + clientService = new JettyWebSocketTestClient(); + + clientServiceContext = new ControllerServiceTestContext(clientService, "JettyWebSocketClient1"); + clientServiceContext.setCustomValue(JettyWebSocketClient.WS_URI, "ws://localhost:" + NetworkUtils.getAvailableTcpPort() + "/test"); + + clientServiceContext.setCustomValue(JettyWebSocketClient.USER_NAME, "user2"); + clientServiceContext.setCustomValue(JettyWebSocketClient.USER_PASSWORD, "password2"); + + clientService.initialize(clientServiceContext.getInitializationContext()); + clientService.startClient(clientServiceContext.getConfigurationContext()); + } + + @Test + void testClientAttemptsToReconnect() throws Exception { + final ITJettyWebSocketCommunication.MockWebSocketProcessor clientProcessor = mock(ITJettyWebSocketCommunication.MockWebSocketProcessor.class); + doReturn("clientProcessor1").when(clientProcessor).getIdentifier(); + + final String clientId = "client1"; + + clientService.registerProcessor(clientId, clientProcessor); + + assertThrows(IOException.class, + () -> clientService.connect(clientId) + ); + + JettyWebSocketTestClient testClientService = (JettyWebSocketTestClient) clientService; + verify(testClientService.getMockSession(), times(3)).get(anyLong(), any(TimeUnit.class)); + } + + private static class JettyWebSocketTestClient extends JettyWebSocketClient { + private CompletableFuture mockSession; + + public JettyWebSocketTestClient() throws ExecutionException, InterruptedException, TimeoutException { + mockSession = mock(CompletableFuture.class); + when(mockSession.get(anyLong(), any(TimeUnit.class))).thenThrow(new RuntimeException("Test: Connecting timed out.")); + } + + Future createWebsocketSession(RoutingWebSocketListener listener, ClientUpgradeRequest request) { + return mockSession; + } + + public CompletableFuture getMockSession() { + return mockSession; + } + } + +}