From ef34555b29428d3470f0b34a34f30b881019e38d Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Wed, 22 Nov 2017 11:39:31 -0600 Subject: [PATCH] Decouple nio constructs from the tcp transport (#27484) This is related to #27260. Currently, basic nio constructs (nio channels, the channel factories, selector event handlers, etc) implement logic that is specific to the tcp transport. For example, NioChannel implements the TcpChannel interface. These nio constructs at some point will also need to support other protocols (ex: http). This commit separates the TcpTransport logic from the nio building blocks. --- .../transport/nio/AcceptorEventHandler.java | 15 +-- .../transport/nio/EventHandler.java | 5 +- .../transport/nio/NioTransport.java | 52 +++++++---- .../transport/nio/OpenChannels.java | 18 ++-- .../transport/nio/SocketEventHandler.java | 9 +- .../transport/nio/TcpReadHandler.java | 3 +- .../nio/channel/AbstractNioChannel.java | 23 ++--- .../transport/nio/channel/ChannelFactory.java | 91 +++++++++---------- .../transport/nio/channel/NioChannel.java | 15 ++- .../nio/channel/NioServerSocketChannel.java | 19 +++- .../nio/channel/NioSocketChannel.java | 25 +++-- .../nio/channel/TcpChannelFactory.java | 66 ++++++++++++++ .../channel/TcpNioServerSocketChannel.java | 57 ++++++++++++ .../nio/channel/TcpNioSocketChannel.java | 55 +++++++++++ .../transport/nio/channel/TcpReadContext.java | 6 +- .../nio/AcceptorEventHandlerTests.java | 29 ++---- .../nio/SimpleNioTransportTests.java | 2 +- .../nio/SocketEventHandlerTests.java | 4 +- .../transport/nio/SocketSelectorTests.java | 1 - .../nio/TestingSocketEventHandler.java | 5 +- .../transport/nio/WriteOperationTests.java | 2 - .../nio/channel/ChannelFactoryTests.java | 34 ++++--- .../channel/NioServerSocketChannelTests.java | 6 +- .../nio/channel/NioSocketChannelTests.java | 20 ++-- .../nio/channel/TcpReadContextTests.java | 6 +- 25 files changed, 376 insertions(+), 192 deletions(-) create mode 100644 test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java create mode 100644 test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java create mode 100644 test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java index 49bba47ef02..ba0fa9356b9 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java @@ -22,13 +22,11 @@ package org.elasticsearch.transport.nio; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.transport.nio.channel.ChannelFactory; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; import java.io.IOException; -import java.util.function.Consumer; import java.util.function.Supplier; /** @@ -37,15 +35,10 @@ import java.util.function.Supplier; public class AcceptorEventHandler extends EventHandler { private final Supplier selectorSupplier; - private final Consumer acceptedChannelCallback; - private final OpenChannels openChannels; - public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier selectorSupplier, - Consumer acceptedChannelCallback) { - super(logger, openChannels); - this.openChannels = openChannels; + public AcceptorEventHandler(Logger logger, Supplier selectorSupplier) { + super(logger); this.selectorSupplier = selectorSupplier; - this.acceptedChannelCallback = acceptedChannelCallback; } /** @@ -56,7 +49,6 @@ public class AcceptorEventHandler extends EventHandler { */ void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) { SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel); - openChannels.serverChannelOpened(nioServerSocketChannel); } /** @@ -79,8 +71,7 @@ public class AcceptorEventHandler extends EventHandler { ChannelFactory channelFactory = nioServerChannel.getChannelFactory(); SocketSelector selector = selectorSupplier.get(); NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector); - openChannels.acceptedChannelOpened(nioSocketChannel); - acceptedChannelCallback.accept(nioSocketChannel); + nioServerChannel.getAcceptContext().accept(nioSocketChannel); } /** diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java index 59e866036cc..8521f716162 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java @@ -29,11 +29,9 @@ import java.nio.channels.Selector; public abstract class EventHandler { protected final Logger logger; - private final OpenChannels openChannels; - public EventHandler(Logger logger, OpenChannels openChannels) { + public EventHandler(Logger logger) { this.logger = logger; - this.openChannels = openChannels; } /** @@ -71,7 +69,6 @@ public abstract class EventHandler { * @param channel that should be closed */ void handleClose(NioChannel channel) { - openChannels.channelClosed(channel); try { channel.closeFromSelector(); } catch (IOException e) { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 38c897b3be2..d1ab10fb568 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -33,10 +33,12 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.Transports; -import org.elasticsearch.transport.nio.channel.ChannelFactory; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpChannelFactory; +import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel; import org.elasticsearch.transport.nio.channel.TcpReadContext; import org.elasticsearch.transport.nio.channel.TcpWriteContext; @@ -65,12 +67,12 @@ public class NioTransport extends TcpTransport { public static final Setting NIO_ACCEPTOR_COUNT = intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); - protected final OpenChannels openChannels = new OpenChannels(logger); - private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); + private final OpenChannels openChannels = new OpenChannels(logger); + private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); private final ArrayList acceptors = new ArrayList<>(); private final ArrayList socketSelectors = new ArrayList<>(); private RoundRobinSelectorSupplier clientSelectorSupplier; - private ChannelFactory clientChannelFactory; + private TcpChannelFactory clientChannelFactory; private int acceptorNumber; public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, @@ -84,17 +86,21 @@ public class NioTransport extends TcpTransport { } @Override - protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { - ChannelFactory channelFactory = this.profileToChannelFactory.get(name); + protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { + TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name); AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings)); - return channelFactory.openNioServerSocketChannel(address, selector); + TcpNioServerSocketChannel serverChannel = channelFactory.openNioServerSocketChannel(address, selector); + openChannels.serverChannelOpened(serverChannel); + serverChannel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(serverChannel))); + return serverChannel; } @Override - protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) + protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) throws IOException { - NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get()); + TcpNioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get()); openChannels.clientChannelOpened(channel); + channel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(channel))); channel.addConnectListener(connectListener); return channel; } @@ -119,14 +125,14 @@ public class NioTransport extends TcpTransport { Consumer clientContextSetter = getContextSetter("client-socket"); clientSelectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); - clientChannelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), clientContextSetter); + ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); + clientChannelFactory = new TcpChannelFactory(clientProfileSettings, clientContextSetter, getServerContextSetter()); if (NetworkService.NETWORK_SERVER.get(settings)) { int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings); for (int i = 0; i < acceptorCount; ++i) { Supplier selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); - AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier, - this::serverAcceptedChannel); + AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, selectorSupplier); AcceptingSelector acceptor = new AcceptingSelector(eventHandler); acceptors.add(acceptor); } @@ -143,7 +149,8 @@ public class NioTransport extends TcpTransport { for (ProfileSettings profileSettings : profileSettings) { String profileName = profileSettings.profileName; Consumer contextSetter = getContextSetter(profileName); - profileToChannelFactory.putIfAbsent(profileName, new ChannelFactory(profileSettings, contextSetter)); + TcpChannelFactory factory = new TcpChannelFactory(profileSettings, contextSetter, getServerContextSetter()); + profileToChannelFactory.putIfAbsent(profileName, factory); bindServer(profileSettings); } } @@ -169,14 +176,27 @@ public class NioTransport extends TcpTransport { } protected SocketEventHandler getSocketEventHandler() { - return new SocketEventHandler(logger, this::exceptionCaught, openChannels); + return new SocketEventHandler(logger); } final void exceptionCaught(NioSocketChannel channel, Exception exception) { - onException(channel, exception); + onException((TcpNioSocketChannel) channel, exception); } private Consumer getContextSetter(String profileName) { - return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c)); + return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c), + this::exceptionCaught); + } + + private void acceptChannel(NioSocketChannel channel) { + TcpNioSocketChannel tcpChannel = (TcpNioSocketChannel) channel; + openChannels.acceptedChannelOpened(tcpChannel); + tcpChannel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(channel))); + serverAcceptedChannel(tcpChannel); + + } + + private Consumer getServerContextSetter() { + return (c) -> c.setAcceptContext(this::acceptChannel); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java index 68bb2f99bf3..12c12aaa48e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java @@ -25,6 +25,8 @@ import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel; import java.util.ArrayList; import java.util.HashSet; @@ -38,9 +40,9 @@ import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.new public class OpenChannels implements Releasable { // TODO: Maybe set concurrency levels? - private final ConcurrentMap openClientChannels = newConcurrentMap(); - private final ConcurrentMap openAcceptedChannels = newConcurrentMap(); - private final ConcurrentMap openServerChannels = newConcurrentMap(); + private final ConcurrentMap openClientChannels = newConcurrentMap(); + private final ConcurrentMap openAcceptedChannels = newConcurrentMap(); + private final ConcurrentMap openServerChannels = newConcurrentMap(); private final Logger logger; @@ -48,7 +50,7 @@ public class OpenChannels implements Releasable { this.logger = logger; } - public void serverChannelOpened(NioServerSocketChannel channel) { + public void serverChannelOpened(TcpNioServerSocketChannel channel) { boolean added = openServerChannels.putIfAbsent(channel, System.nanoTime()) == null; if (added && logger.isTraceEnabled()) { logger.trace("server channel opened: {}", channel); @@ -59,7 +61,7 @@ public class OpenChannels implements Releasable { return openServerChannels.size(); } - public void acceptedChannelOpened(NioSocketChannel channel) { + public void acceptedChannelOpened(TcpNioSocketChannel channel) { boolean added = openAcceptedChannels.putIfAbsent(channel, System.nanoTime()) == null; if (added && logger.isTraceEnabled()) { logger.trace("accepted channel opened: {}", channel); @@ -70,14 +72,14 @@ public class OpenChannels implements Releasable { return new HashSet<>(openAcceptedChannels.keySet()); } - public void clientChannelOpened(NioSocketChannel channel) { + public void clientChannelOpened(TcpNioSocketChannel channel) { boolean added = openClientChannels.putIfAbsent(channel, System.nanoTime()) == null; if (added && logger.isTraceEnabled()) { logger.trace("client channel opened: {}", channel); } } - public Map getClientChannels() { + public Map getClientChannels() { return openClientChannels; } @@ -105,7 +107,7 @@ public class OpenChannels implements Releasable { @Override public void close() { - Stream channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream()); + Stream channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream()); TcpChannel.closeChannels(channels.collect(Collectors.toList()), true); openClientChannels.clear(); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java index 46292f63d1b..50362c5a665 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java @@ -27,19 +27,16 @@ import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; import org.elasticsearch.transport.nio.channel.WriteContext; import java.io.IOException; -import java.util.function.BiConsumer; /** * Event handler designed to handle events from non-server sockets */ public class SocketEventHandler extends EventHandler { - private final BiConsumer exceptionHandler; private final Logger logger; - public SocketEventHandler(Logger logger, BiConsumer exceptionHandler, OpenChannels openChannels) { - super(logger, openChannels); - this.exceptionHandler = exceptionHandler; + public SocketEventHandler(Logger logger) { + super(logger); this.logger = logger; } @@ -150,6 +147,6 @@ public class SocketEventHandler extends EventHandler { } private void exceptionCaught(NioSocketChannel channel, Exception e) { - exceptionHandler.accept(channel, e); + channel.getExceptionContext().accept(channel, e); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java index 1260546d34c..5c2ecea54c3 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java @@ -21,6 +21,7 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel; import java.io.IOException; @@ -34,7 +35,7 @@ public class TcpReadHandler { this.transport = transport; } - public void handleMessage(BytesReference reference, NioSocketChannel channel, int messageBytesLength) { + public void handleMessage(BytesReference reference, TcpNioSocketChannel channel, int messageBytesLength) { try { transport.messageReceived(reference, channel, profile, channel.getRemoteAddress(), messageBytesLength); } catch (IOException e) { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java index 7743fe0d83c..7b08d831df8 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java @@ -137,25 +137,18 @@ public abstract class AbstractNioChannel listener) { closeContext.whenComplete(ActionListener.toBiConsumer(listener)); } - @Override - public void setSoLinger(int value) throws IOException { - if (isOpen()) { - socketChannel.setOption(StandardSocketOptions.SO_LINGER, value); - } + // Package visibility for testing + void setSelectionKey(SelectionKey selectionKey) { + this.selectionKey = selectionKey; + } + // Package visibility for testing + + void closeRawChannel() throws IOException { + socketChannel.close(); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java index 84385de0626..97433cf4d0a 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java @@ -19,74 +19,79 @@ package org.elasticsearch.transport.nio.channel; - import org.elasticsearch.mocksocket.PrivilegedSocketAccess; -import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.SocketSelector; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; -import java.net.ServerSocket; -import java.net.Socket; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.util.function.Consumer; -public class ChannelFactory { +public abstract class ChannelFactory { - private final Consumer contextSetter; - private final RawChannelFactory rawChannelFactory; + private final ChannelFactory.RawChannelFactory rawChannelFactory; /** - * This will create a {@link ChannelFactory} using the profile settings and context setter passed to this - * constructor. The context setter must be a {@link Consumer} that calls - * {@link NioSocketChannel#setContexts(ReadContext, WriteContext)} with the appropriate read and write - * contexts. The read and write contexts handle the protocol specific encoding and decoding of messages. + * This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor. * - * @param profileSettings the profile settings channels opened by this factory - * @param contextSetter a consumer that takes a channel and sets the read and write contexts + * @param rawChannelFactory a factory that will construct the raw socket channels */ - public ChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer contextSetter) { - this(new RawChannelFactory(profileSettings.tcpNoDelay, - profileSettings.tcpKeepAlive, - profileSettings.reuseAddress, - Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), contextSetter); - } - - ChannelFactory(RawChannelFactory rawChannelFactory, Consumer contextSetter) { - this.contextSetter = contextSetter; + ChannelFactory(RawChannelFactory rawChannelFactory) { this.rawChannelFactory = rawChannelFactory; } - public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector) throws IOException { + public Socket openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector) throws IOException { SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); - NioSocketChannel channel = createChannel(selector, rawChannel); + Socket channel = internalCreateChannel(selector, rawChannel); scheduleChannel(channel, selector); return channel; } - public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { + public Socket acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); - NioSocketChannel channel = createChannel(selector, rawChannel); + Socket channel = internalCreateChannel(selector, rawChannel); scheduleChannel(channel, selector); return channel; } - public NioServerSocketChannel openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector) - throws IOException { + public ServerSocket openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector) throws IOException { ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); - NioServerSocketChannel serverChannel = createServerChannel(selector, rawChannel); + ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel); scheduleServerChannel(serverChannel, selector); return serverChannel; } - private NioSocketChannel createChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException { + /** + * This method should return a new {@link NioSocketChannel} implementation. When this method has + * returned, the channel should be fully created and setup. Read and write contexts and the channel + * exception handler should have been set. + * + * @param selector the channel will be registered with + * @param channel the raw channel + * @return the channel + * @throws IOException related to the creation of the channel + */ + public abstract Socket createChannel(SocketSelector selector, SocketChannel channel) throws IOException; + + /** + * This method should return a new {@link NioServerSocketChannel} implementation. When this method has + * returned, the channel should be fully created and setup. + * + * @param selector the channel will be registered with + * @param channel the raw channel + * @return the server channel + * @throws IOException related to the creation of the channel + */ + public abstract ServerSocket createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException; + + private Socket internalCreateChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException { try { - NioSocketChannel channel = new NioSocketChannel(rawChannel, selector); - setContexts(channel); + Socket channel = createChannel(selector, rawChannel); + assert channel.getReadContext() != null : "read context should have been set on channel"; + assert channel.getWriteContext() != null : "write context should have been set on channel"; + assert channel.getExceptionContext() != null : "exception handler should have been set on channel"; return channel; } catch (Exception e) { closeRawChannel(rawChannel, e); @@ -94,16 +99,16 @@ public class ChannelFactory { } } - private NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel rawChannel) throws IOException { + private ServerSocket internalCreateServerChannel(AcceptingSelector selector, ServerSocketChannel rawChannel) throws IOException { try { - return new NioServerSocketChannel(rawChannel, this, selector); + return createServerChannel(selector, rawChannel); } catch (Exception e) { closeRawChannel(rawChannel, e); throw e; } } - private void scheduleChannel(NioSocketChannel channel, SocketSelector selector) { + private void scheduleChannel(Socket channel, SocketSelector selector) { try { selector.scheduleForRegistration(channel); } catch (IllegalStateException e) { @@ -112,7 +117,7 @@ public class ChannelFactory { } } - private void scheduleServerChannel(NioServerSocketChannel channel, AcceptingSelector selector) { + private void scheduleServerChannel(ServerSocket channel, AcceptingSelector selector) { try { selector.scheduleForRegistration(channel); } catch (IllegalStateException e) { @@ -121,12 +126,6 @@ public class ChannelFactory { } } - private void setContexts(NioSocketChannel channel) { - contextSetter.accept(channel); - assert channel.getReadContext() != null : "read context should have been set on channel"; - assert channel.getWriteContext() != null : "write context should have been set on channel"; - } - private static void closeRawChannel(Closeable c, Exception e) { try { c.close(); @@ -179,7 +178,7 @@ public class ChannelFactory { ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws IOException { ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); - ServerSocket socket = serverSocketChannel.socket(); + java.net.ServerSocket socket = serverSocketChannel.socket(); try { socket.setReuseAddress(tcpReusedAddress); serverSocketChannel.bind(address); @@ -192,7 +191,7 @@ public class ChannelFactory { private void configureSocketChannel(SocketChannel channel) throws IOException { channel.configureBlocking(false); - Socket socket = channel.socket(); + java.net.Socket socket = channel.socket(); socket.setTcpNoDelay(tcpNoDelay); socket.setKeepAlive(tcpKeepAlive); socket.setReuseAddress(tcpReusedAddress); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java index 76262da6f15..93bc4faa4c5 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java @@ -19,7 +19,7 @@ package org.elasticsearch.transport.nio.channel; -import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.transport.nio.ESSelector; import java.io.IOException; @@ -28,7 +28,9 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.NetworkChannel; import java.nio.channels.SelectionKey; -public interface NioChannel extends TcpChannel { +public interface NioChannel { + + boolean isOpen(); InetSocketAddress getLocalAddress(); @@ -43,4 +45,13 @@ public interface NioChannel extends TcpChannel { SelectionKey getSelectionKey(); NetworkChannel getRawChannel(); + + /** + * Adds a close listener to the channel. Multiple close listeners can be added. There is no guarantee + * about the order in which close listeners will be executed. If the channel is already closed, the + * listener is executed immediately. + * + * @param listener to be called at close + */ + void addCloseListener(ActionListener listener); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java index 0396a53f454..ffbd8f7a987 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java @@ -19,16 +19,16 @@ package org.elasticsearch.transport.nio.channel; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.nio.AcceptingSelector; import java.io.IOException; import java.nio.channels.ServerSocketChannel; +import java.util.function.Consumer; public class NioServerSocketChannel extends AbstractNioChannel { private final ChannelFactory channelFactory; + private Consumer acceptContext; public NioServerSocketChannel(ServerSocketChannel socketChannel, ChannelFactory channelFactory, AcceptingSelector selector) throws IOException { @@ -40,9 +40,18 @@ public class NioServerSocketChannel extends AbstractNioChannel listener) { - throw new UnsupportedOperationException("Cannot send a message to a server channel."); + /** + * This method sets the accept context for a server socket channel. The accept context is called when a + * new channel is accepted. The parameter passed to the context is the new channel. + * + * @param acceptContext to call + */ + public void setAcceptContext(Consumer acceptContext) { + this.acceptContext = acceptContext; + } + + public Consumer getAcceptContext() { + return acceptContext; } @Override diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java index d0c3d9c3330..b56731aee10 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java @@ -20,7 +20,6 @@ package org.elasticsearch.transport.nio.channel; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.nio.NetworkBytesReference; import org.elasticsearch.transport.nio.SocketSelector; @@ -31,14 +30,18 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; import java.util.Arrays; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; public class NioSocketChannel extends AbstractNioChannel { private final InetSocketAddress remoteAddress; private final CompletableFuture connectContext = new CompletableFuture<>(); private final SocketSelector socketSelector; + private final AtomicBoolean contextsSet = new AtomicBoolean(false); private WriteContext writeContext; private ReadContext readContext; + private BiConsumer exceptionContext; private Exception connectException; public NioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { @@ -47,11 +50,6 @@ public class NioSocketChannel extends AbstractNioChannel { this.socketSelector = selector; } - @Override - public void sendMessage(BytesReference reference, ActionListener listener) { - writeContext.sendMessage(reference, listener); - } - @Override public void closeFromSelector() throws IOException { assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; @@ -99,9 +97,14 @@ public class NioSocketChannel extends AbstractNioChannel { return bytesRead; } - public void setContexts(ReadContext readContext, WriteContext writeContext) { - this.readContext = readContext; - this.writeContext = writeContext; + public void setContexts(ReadContext readContext, WriteContext writeContext, BiConsumer exceptionContext) { + if (contextsSet.compareAndSet(false, true)) { + this.readContext = readContext; + this.writeContext = writeContext; + this.exceptionContext = exceptionContext; + } else { + throw new IllegalStateException("Contexts on this channel were already set. They should only be once."); + } } public WriteContext getWriteContext() { @@ -112,6 +115,10 @@ public class NioSocketChannel extends AbstractNioChannel { return readContext; } + public BiConsumer getExceptionContext() { + return exceptionContext; + } + public InetSocketAddress getRemoteAddress() { return remoteAddress; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java new file mode 100644 index 00000000000..03d6db18e5a --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.nio.AcceptingSelector; +import org.elasticsearch.transport.nio.SocketSelector; + +import java.io.IOException; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.function.Consumer; + +/** + * This is an implementation of {@link ChannelFactory} which returns channels that adhere to the + * {@link org.elasticsearch.transport.TcpChannel} interface. The channels will use the provided + * {@link TcpTransport.ProfileSettings}. The provided context setters will be called with the channel after + * construction. + */ +public class TcpChannelFactory extends ChannelFactory { + + private final Consumer contextSetter; + private final Consumer serverContextSetter; + + public TcpChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer contextSetter, + Consumer serverContextSetter) { + super(new RawChannelFactory(profileSettings.tcpNoDelay, + profileSettings.tcpKeepAlive, + profileSettings.reuseAddress, + Math.toIntExact(profileSettings.sendBufferSize.getBytes()), + Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); + this.contextSetter = contextSetter; + this.serverContextSetter = serverContextSetter; + } + + @Override + public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { + TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(channel, selector); + contextSetter.accept(nioChannel); + return nioChannel; + } + + @Override + public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { + TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(channel, this, selector); + serverContextSetter.accept(nioServerChannel); + return nioServerChannel; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java new file mode 100644 index 00000000000..496295bd320 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.transport.nio.AcceptingSelector; + +import java.io.IOException; +import java.nio.channels.ServerSocketChannel; + +/** + * This is an implementation of {@link NioServerSocketChannel} that adheres to the {@link TcpChannel} + * interface. As it is a server socket, setting SO_LINGER and sending messages is not supported. + */ +public class TcpNioServerSocketChannel extends NioServerSocketChannel implements TcpChannel { + + TcpNioServerSocketChannel(ServerSocketChannel socketChannel, TcpChannelFactory channelFactory, AcceptingSelector selector) + throws IOException { + super(socketChannel, channelFactory, selector); + } + + @Override + public void sendMessage(BytesReference reference, ActionListener listener) { + throw new UnsupportedOperationException("Cannot send a message to a server channel."); + } + + @Override + public void setSoLinger(int value) throws IOException { + throw new UnsupportedOperationException("Cannot set SO_LINGER on a server channel."); + } + + @Override + public String toString() { + return "TcpNioServerSocketChannel{" + + "localAddress=" + getLocalAddress() + + '}'; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java new file mode 100644 index 00000000000..f1ee1bd4e67 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java @@ -0,0 +1,55 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.transport.nio.SocketSelector; + +import java.io.IOException; +import java.net.StandardSocketOptions; +import java.nio.channels.SocketChannel; + +public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel { + + public TcpNioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { + super(socketChannel, selector); + } + + public void sendMessage(BytesReference reference, ActionListener listener) { + getWriteContext().sendMessage(reference, listener); + } + + @Override + public void setSoLinger(int value) throws IOException { + if (isOpen()) { + getRawChannel().setOption(StandardSocketOptions.SO_LINGER, value); + } + } + + @Override + public String toString() { + return "TcpNioSocketChannel{" + + "localAddress=" + getLocalAddress() + + ", remoteAddress=" + getRemoteAddress() + + '}'; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java index 57aa16ce15e..8eeb32a976c 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java @@ -34,16 +34,16 @@ public class TcpReadContext implements ReadContext { private static final int DEFAULT_READ_LENGTH = 1 << 14; private final TcpReadHandler handler; - private final NioSocketChannel channel; + private final TcpNioSocketChannel channel; private final TcpFrameDecoder frameDecoder; private final LinkedList references = new LinkedList<>(); private int rawBytesCount = 0; public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler) { - this(channel, handler, new TcpFrameDecoder()); + this((TcpNioSocketChannel) channel, handler, new TcpFrameDecoder()); } - public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler, TcpFrameDecoder frameDecoder) { + public TcpReadContext(TcpNioSocketChannel channel, TcpReadHandler handler, TcpFrameDecoder frameDecoder) { this.handler = handler; this.channel = channel; this.frameDecoder = frameDecoder; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java index 3f23531407c..aedff1721f8 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.transport.nio; -import org.apache.lucene.util.IOUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.nio.channel.ChannelFactory; import org.elasticsearch.transport.nio.channel.DoNotRegisterServerChannel; @@ -34,11 +33,9 @@ import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; +import java.util.function.BiConsumer; import java.util.function.Consumer; -import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -49,7 +46,6 @@ public class AcceptorEventHandlerTests extends ESTestCase { private AcceptorEventHandler handler; private SocketSelector socketSelector; private ChannelFactory channelFactory; - private OpenChannels openChannels; private NioServerSocketChannel channel; private Consumer acceptedChannelCallback; @@ -59,24 +55,16 @@ public class AcceptorEventHandlerTests extends ESTestCase { channelFactory = mock(ChannelFactory.class); socketSelector = mock(SocketSelector.class); acceptedChannelCallback = mock(Consumer.class); - openChannels = new OpenChannels(logger); ArrayList selectors = new ArrayList<>(); selectors.add(socketSelector); - handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors), acceptedChannelCallback); + handler = new AcceptorEventHandler(logger, new RoundRobinSelectorSupplier(selectors)); AcceptingSelector selector = mock(AcceptingSelector.class); channel = new DoNotRegisterServerChannel(mock(ServerSocketChannel.class), channelFactory, selector); + channel.setAcceptContext(acceptedChannelCallback); channel.register(); } - public void testHandleRegisterAdjustsOpenChannels() { - assertEquals(0, openChannels.serverChannelsCount()); - - handler.serverChannelRegistered(channel); - - assertEquals(1, openChannels.serverChannelsCount()); - } - public void testHandleRegisterSetsOP_ACCEPTInterest() { assertEquals(0, channel.getSelectionKey().interestOps()); @@ -96,18 +84,13 @@ public class AcceptorEventHandlerTests extends ESTestCase { } @SuppressWarnings("unchecked") - public void testHandleAcceptAddsToOpenChannelsAndIsRemovedOnClose() throws IOException { - SocketChannel rawChannel = SocketChannel.open(); - NioSocketChannel childChannel = new NioSocketChannel(rawChannel, socketSelector); - childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + public void testHandleAcceptCallsServerAcceptCallback() throws IOException { + NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); + childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); handler.acceptChannel(channel); verify(acceptedChannelCallback).accept(childChannel); - - assertEquals(new HashSet<>(Collections.singletonList(childChannel)), openChannels.getAcceptedChannels()); - - IOUtils.closeWhileHandlingException(rawChannel); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index bc02a89a5c1..55bca45d1c8 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -77,7 +77,7 @@ public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase { @Override protected SocketEventHandler getSocketEventHandler() { - return new TestingSocketEventHandler(logger, this::exceptionCaught, openChannels); + return new TestingSocketEventHandler(logger); } }; MockTransportService mockTransportService = diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java index cd4e70ab3ac..8f270d11e5a 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java @@ -55,13 +55,13 @@ public class SocketEventHandlerTests extends ESTestCase { public void setUpHandler() throws IOException { exceptionHandler = mock(BiConsumer.class); SocketSelector socketSelector = mock(SocketSelector.class); - handler = new SocketEventHandler(logger, exceptionHandler, mock(OpenChannels.class)); + handler = new SocketEventHandler(logger); rawChannel = mock(SocketChannel.class); channel = new DoNotRegisterChannel(rawChannel, socketSelector); readContext = mock(ReadContext.class); when(rawChannel.finishConnect()).thenReturn(true); - channel.setContexts(readContext, new TcpWriteContext(channel)); + channel.setContexts(readContext, new TcpWriteContext(channel), exceptionHandler); channel.register(); channel.finishConnect(); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java index 0de1bb72063..61a9499f8db 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java @@ -22,7 +22,6 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.elasticsearch.transport.nio.channel.WriteContext; import org.elasticsearch.transport.nio.utils.TestSelectionKey; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java index 65759cf7705..a3cb92ad376 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java @@ -26,12 +26,11 @@ import java.io.IOException; import java.util.Collections; import java.util.Set; import java.util.WeakHashMap; -import java.util.function.BiConsumer; public class TestingSocketEventHandler extends SocketEventHandler { - public TestingSocketEventHandler(Logger logger, BiConsumer exceptionHandler, OpenChannels openChannels) { - super(logger, exceptionHandler, openChannels); + public TestingSocketEventHandler(Logger logger) { + super(logger); } private Set hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java index 1f6f95e62af..351ac87eb56 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java @@ -22,7 +22,6 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.junit.Before; @@ -30,7 +29,6 @@ import java.io.IOException; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class WriteOperationTests extends ESTestCase { diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java index f6bcf26a02c..91e1c2023e7 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java @@ -30,11 +30,10 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.util.function.Consumer; +import java.util.function.BiConsumer; import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -52,19 +51,12 @@ public class ChannelFactoryTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void setupFactory() throws IOException { - rawChannelFactory = mock(ChannelFactory.RawChannelFactory.class); - Consumer contextSetter = mock(Consumer.class); - channelFactory = new ChannelFactory(rawChannelFactory, contextSetter); + rawChannelFactory = mock(TcpChannelFactory.RawChannelFactory.class); + channelFactory = new TestChannelFactory(rawChannelFactory); socketSelector = mock(SocketSelector.class); acceptingSelector = mock(AcceptingSelector.class); rawChannel = SocketChannel.open(); rawServerChannel = ServerSocketChannel.open(); - - doAnswer(invocationOnMock -> { - NioSocketChannel channel = (NioSocketChannel) invocationOnMock.getArguments()[0]; - channel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); - return null; - }).when(contextSetter).accept(any()); } @After @@ -138,4 +130,24 @@ public class ChannelFactoryTests extends ESTestCase { assertFalse(rawServerChannel.isOpen()); } + + private static class TestChannelFactory extends ChannelFactory { + + TestChannelFactory(RawChannelFactory rawChannelFactory) { + super(rawChannelFactory); + } + + @SuppressWarnings("unchecked") + @Override + public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { + NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector); + nioSocketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + return nioSocketChannel; + } + + @Override + public NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { + return new NioServerSocketChannel(channel, this, selector); + } + } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java index 9c01f5edc61..ba5d47fe8f8 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java @@ -22,10 +22,8 @@ package org.elasticsearch.transport.nio.channel; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.AcceptorEventHandler; -import org.elasticsearch.transport.nio.OpenChannels; import org.junit.After; import org.junit.Before; @@ -33,8 +31,6 @@ import java.io.IOException; import java.nio.channels.ServerSocketChannel; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.function.Supplier; import static org.mockito.Mockito.mock; @@ -48,7 +44,7 @@ public class NioServerSocketChannelTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void setSelector() throws IOException { - selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(OpenChannels.class), mock(Supplier.class), (c) -> {})); + selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(Supplier.class))); thread = new Thread(selector::runLoop); closedRawChannel = new AtomicBoolean(false); thread.start(); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java index e3053a3e73a..fecaf8fe970 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java @@ -22,8 +22,6 @@ package org.elasticsearch.transport.nio.channel; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.TcpChannel; -import org.elasticsearch.transport.nio.OpenChannels; import org.elasticsearch.transport.nio.SocketEventHandler; import org.elasticsearch.transport.nio.SocketSelector; import org.junit.After; @@ -36,9 +34,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; -import java.util.function.Consumer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -48,13 +44,11 @@ public class NioSocketChannelTests extends ESTestCase { private SocketSelector selector; private AtomicBoolean closedRawChannel; private Thread thread; - private OpenChannels openChannels; @Before @SuppressWarnings("unchecked") public void startSelector() throws IOException { - openChannels = new OpenChannels(logger); - selector = new SocketSelector(new SocketEventHandler(logger, mock(BiConsumer.class), openChannels)); + selector = new SocketSelector(new SocketEventHandler(logger)); thread = new Thread(selector::runLoop); closedRawChannel = new AtomicBoolean(false); thread.start(); @@ -67,13 +61,13 @@ public class NioSocketChannelTests extends ESTestCase { thread.join(); } + @SuppressWarnings("unchecked") public void testClose() throws Exception { AtomicBoolean isClosed = new AtomicBoolean(false); CountDownLatch latch = new CountDownLatch(1); NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector); - openChannels.clientChannelOpened(socketChannel); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); socketChannel.addCloseListener(new ActionListener() { @Override public void onResponse(Void o) { @@ -90,7 +84,6 @@ public class NioSocketChannelTests extends ESTestCase { assertTrue(socketChannel.isOpen()); assertFalse(closedRawChannel.get()); assertFalse(isClosed.get()); - assertTrue(openChannels.getClientChannels().containsKey(socketChannel)); PlainActionFuture closeFuture = PlainActionFuture.newFuture(); socketChannel.addCloseListener(closeFuture); @@ -99,16 +92,16 @@ public class NioSocketChannelTests extends ESTestCase { assertTrue(closedRawChannel.get()); assertFalse(socketChannel.isOpen()); - assertFalse(openChannels.getClientChannels().containsKey(socketChannel)); latch.await(); assertTrue(isClosed.get()); } + @SuppressWarnings("unchecked") public void testConnectSucceeds() throws Exception { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenReturn(true); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); selector.scheduleForRegistration(socketChannel); PlainActionFuture connectFuture = PlainActionFuture.newFuture(); @@ -120,11 +113,12 @@ public class NioSocketChannelTests extends ESTestCase { assertFalse(closedRawChannel.get()); } + @SuppressWarnings("unchecked") public void testConnectFails() throws Exception { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenThrow(new ConnectException()); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); selector.scheduleForRegistration(socketChannel); PlainActionFuture connectFuture = PlainActionFuture.newFuture(); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java index 2dc0b32ae5b..7586b5abd91 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java @@ -39,10 +39,9 @@ import static org.mockito.Mockito.when; public class TcpReadContextTests extends ESTestCase { - private static String PROFILE = "profile"; private TcpReadHandler handler; private int messageLength; - private NioSocketChannel channel; + private TcpNioSocketChannel channel; private TcpReadContext readContext; @Before @@ -50,7 +49,7 @@ public class TcpReadContextTests extends ESTestCase { handler = mock(TcpReadHandler.class); messageLength = randomInt(96) + 4; - channel = mock(NioSocketChannel.class); + channel = mock(TcpNioSocketChannel.class); readContext = new TcpReadContext(channel, handler); } @@ -144,5 +143,4 @@ public class TcpReadContextTests extends ESTestCase { } return bytes; } - }