NIFI-9878 Added timeout handling for Cache Client handshaking

This closes #6414

Co-authored-by: Nissim Shiman <nshiman@yahoo.com>
Co-authored-by: Jon Shoemaker <jon.l.shoemaker@systolic.com>
Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Jon Shoemaker 2022-09-13 15:48:29 +00:00 committed by exceptionfactory
parent b862fff8f0
commit 9a4ce2607d
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
4 changed files with 101 additions and 14 deletions

View File

@ -78,7 +78,7 @@ public class CacheClientChannelInitializer extends ChannelInitializer<Channel> {
final VersionNegotiator versionNegotiator = versionNegotiatorFactory.create(); final VersionNegotiator versionNegotiator = versionNegotiatorFactory.create();
channelPipeline.addFirst(new IdleStateHandler(idleTimeout.getSeconds(), idleTimeout.getSeconds(), idleTimeout.getSeconds(), TimeUnit.SECONDS)); channelPipeline.addFirst(new IdleStateHandler(idleTimeout.getSeconds(), idleTimeout.getSeconds(), idleTimeout.getSeconds(), TimeUnit.SECONDS));
channelPipeline.addLast(new WriteTimeoutHandler(writeTimeout.toMillis(), TimeUnit.MILLISECONDS)); channelPipeline.addLast(new WriteTimeoutHandler(writeTimeout.toMillis(), TimeUnit.MILLISECONDS));
channelPipeline.addLast(new CacheClientHandshakeHandler(channel, versionNegotiator)); channelPipeline.addLast(new CacheClientHandshakeHandler(channel, versionNegotiator, writeTimeout.toMillis()));
channelPipeline.addLast(new CacheClientRequestHandler()); channelPipeline.addLast(new CacheClientRequestHandler());
channelPipeline.addLast(new CloseContextIdleStateHandler()); channelPipeline.addLast(new CloseContextIdleStateHandler());
} }

View File

@ -32,6 +32,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
/** /**
@ -64,24 +65,37 @@ public class CacheClientHandshakeHandler extends ChannelInboundHandlerAdapter {
*/ */
private final VersionNegotiator versionNegotiator; private final VersionNegotiator versionNegotiator;
/**
* THe network timeout associated with handshake completion
*/
private final long timeoutMillis;
/** /**
* Constructor. * Constructor.
* *
* @param channel the channel to which this {@link io.netty.channel.ChannelHandler} is bound. * @param channel the channel to which this {@link io.netty.channel.ChannelHandler} is bound.
* @param versionNegotiator coordinator used to broker the version of the distributed cache protocol with the service * @param versionNegotiator coordinator used to broker the version of the distributed cache protocol with the service
* @param timeoutMillis the network timeout associated with handshake completion
*/ */
public CacheClientHandshakeHandler(final Channel channel, final VersionNegotiator versionNegotiator) { public CacheClientHandshakeHandler(final Channel channel, final VersionNegotiator versionNegotiator,
final long timeoutMillis) {
this.promiseHandshakeComplete = channel.newPromise(); this.promiseHandshakeComplete = channel.newPromise();
this.protocol = new AtomicInteger(PROTOCOL_UNINITIALIZED); this.protocol = new AtomicInteger(PROTOCOL_UNINITIALIZED);
this.versionNegotiator = versionNegotiator; this.versionNegotiator = versionNegotiator;
this.timeoutMillis = timeoutMillis;
} }
/** /**
* API providing client application with visibility into the handshake process. Distributed cache requests * API providing client application with visibility into the handshake process. Distributed cache requests
* should not be sent using this {@link Channel} until the handshake is complete. * should not be sent using this {@link Channel} until the handshake is complete. Since the handshake might fail,
* {@link #isSuccess()} should be called after this method completes.
*/ */
public void waitHandshakeComplete() { public void waitHandshakeComplete() {
promiseHandshakeComplete.awaitUninterruptibly(); promiseHandshakeComplete.awaitUninterruptibly(timeoutMillis, TimeUnit.MILLISECONDS);
if (!promiseHandshakeComplete.isSuccess()) {
HandshakeException ex = new HandshakeException("Handshake timed out before completion.");
promiseHandshakeComplete.setFailure(ex);
}
} }
/** /**
@ -157,4 +171,22 @@ public class CacheClientHandshakeHandler extends ChannelInboundHandlerAdapter {
promiseHandshakeComplete.setSuccess(); promiseHandshakeComplete.setSuccess();
} }
} }
/**
* Returns if the handshake completed successfully
*
* @return success/failure of handshake
*/
public boolean isSuccess() {
return promiseHandshakeComplete.isSuccess();
}
/**
* Return reason for handshake failure.
*
* @return cause for handshake failure or null on success
*/
public Throwable cause() {
return promiseHandshakeComplete.cause();
}
} }

View File

@ -90,16 +90,20 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
public void invoke(final Channel channel, final OutboundAdapter outboundAdapter, final InboundAdapter inboundAdapter) throws IOException { public void invoke(final Channel channel, final OutboundAdapter outboundAdapter, final InboundAdapter inboundAdapter) throws IOException {
final CacheClientHandshakeHandler handshakeHandler = channel.pipeline().get(CacheClientHandshakeHandler.class); final CacheClientHandshakeHandler handshakeHandler = channel.pipeline().get(CacheClientHandshakeHandler.class);
handshakeHandler.waitHandshakeComplete(); handshakeHandler.waitHandshakeComplete();
if (handshakeHandler.getVersionNegotiator().getVersion() < outboundAdapter.getMinimumVersion()) { if (handshakeHandler.isSuccess()) {
throw new UnsupportedOperationException("Remote cache server doesn't support protocol version " + outboundAdapter.getMinimumVersion()); if (handshakeHandler.getVersionNegotiator().getVersion() < outboundAdapter.getMinimumVersion()) {
} throw new UnsupportedOperationException("Remote cache server doesn't support protocol version " + outboundAdapter.getMinimumVersion());
this.inboundAdapter = inboundAdapter; }
channelPromise = channel.newPromise(); this.inboundAdapter = inboundAdapter;
channel.writeAndFlush(Unpooled.wrappedBuffer(outboundAdapter.toBytes())); channelPromise = channel.newPromise();
channelPromise.awaitUninterruptibly(); channel.writeAndFlush(Unpooled.wrappedBuffer(outboundAdapter.toBytes()));
this.inboundAdapter = new NullInboundAdapter(); channelPromise.awaitUninterruptibly();
if (channelPromise.cause() != null) { this.inboundAdapter = new NullInboundAdapter();
throw new IOException("Request invocation failed", channelPromise.cause()); if (channelPromise.cause() != null) {
throw new IOException("Request invocation failed", channelPromise.cause());
}
} else {
throw new IOException("Request invocation failed", handshakeHandler.cause());
} }
} }
} }

View File

@ -29,6 +29,14 @@ import org.apache.nifi.distributed.cache.protocol.ProtocolVersion;
import org.apache.nifi.distributed.cache.server.CacheServer; import org.apache.nifi.distributed.cache.server.CacheServer;
import org.apache.nifi.distributed.cache.server.DistributedCacheServer; import org.apache.nifi.distributed.cache.server.DistributedCacheServer;
import org.apache.nifi.distributed.cache.server.EvictionPolicy; import org.apache.nifi.distributed.cache.server.EvictionPolicy;
import org.apache.nifi.event.transport.EventServer;
import org.apache.nifi.event.transport.configuration.ShutdownQuietPeriod;
import org.apache.nifi.event.transport.configuration.ShutdownTimeout;
import org.apache.nifi.event.transport.configuration.TransportProtocol;
import org.apache.nifi.event.transport.message.ByteArrayMessage;
import org.apache.nifi.event.transport.netty.ByteArrayMessageNettyEventServerFactory;
import org.apache.nifi.event.transport.netty.NettyEventServerFactory;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.DataUnit; import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.processor.Processor; import org.apache.nifi.processor.Processor;
import org.apache.nifi.remote.StandardVersionNegotiator; import org.apache.nifi.remote.StandardVersionNegotiator;
@ -44,6 +52,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.net.InetAddress;
import java.net.UnknownHostException;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -52,6 +63,8 @@ import java.nio.charset.StandardCharsets;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -298,6 +311,44 @@ public class TestDistributedMapServerAndClient {
} }
} }
@Test
public void testIncompleteHandshakeScenario() throws InitializationException, IOException {
// Default port used by Distributed Server and Client
final int port = NetworkUtils.getAvailableTcpPort();
// This is used to simulate a DistributedCacheServer that does not complete the handshake response
final BlockingQueue<ByteArrayMessage> messages = new LinkedBlockingQueue<>();
final NettyEventServerFactory serverFactory = getEventServerFactory(port, messages);
final EventServer eventServer = serverFactory.getEventServer();
DistributedMapCacheClientService client = new DistributedMapCacheClientService();
runner.addControllerService("client", client);
runner.setProperty(client, DistributedMapCacheClientService.HOSTNAME, "localhost");
runner.setProperty(client, DistributedMapCacheClientService.PORT, String.valueOf(port));
runner.setProperty(client, DistributedMapCacheClientService.COMMUNICATIONS_TIMEOUT, "250 ms");
runner.enableControllerService(client);
final Serializer<String> valueSerializer = new StringSerializer();
final Serializer<String> keySerializer = new StringSerializer();
final Deserializer<String> deserializer = new StringDeserializer();
try {
assertThrows(IOException.class, () -> client.getAndPutIfAbsent("testKey", "test", keySerializer, valueSerializer, deserializer));
} finally {
eventServer.shutdown();
}
}
private NettyEventServerFactory getEventServerFactory(final int port, final BlockingQueue<ByteArrayMessage> messages) throws UnknownHostException {
final ByteArrayMessageNettyEventServerFactory factory = new ByteArrayMessageNettyEventServerFactory(Mockito.mock(ComponentLog.class),
InetAddress.getByName("127.0.0.1"), port, TransportProtocol.TCP, "\n".getBytes(), 1024, messages);
factory.setWorkerThreads(1);
factory.setShutdownQuietPeriod(ShutdownQuietPeriod.QUICK.getDuration());
factory.setShutdownTimeout(ShutdownTimeout.QUICK.getDuration());
return factory;
}
private DistributedMapCacheClientService createClient(final int port) throws InitializationException { private DistributedMapCacheClientService createClient(final int port) throws InitializationException {
final DistributedMapCacheClientService client = new DistributedMapCacheClientService(); final DistributedMapCacheClientService client = new DistributedMapCacheClientService();
final MockControllerServiceInitializationContext clientInitContext = new MockControllerServiceInitializationContext(client, "client"); final MockControllerServiceInitializationContext clientInitContext = new MockControllerServiceInitializationContext(client, "client");