diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 9eabcc56f28..686432722a2 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -41,6 +41,8 @@ import org.elasticsearch.transport.nio.channel.CloseFuture; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpReadContext; +import org.elasticsearch.transport.nio.channel.TcpWriteContext; import java.io.IOException; import java.net.InetSocketAddress; @@ -68,7 +70,7 @@ public class NioTransport extends TcpTransport { public static final Setting NIO_ACCEPTOR_COUNT = intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); - private final TcpReadHandler tcpReadHandler = new TcpReadHandler(this); + private final Consumer contextSetter; private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); private final OpenChannels openChannels = new OpenChannels(logger); private final ArrayList acceptors = new ArrayList<>(); @@ -79,6 +81,7 @@ public class NioTransport extends TcpTransport { public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + contextSetter = (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(this)), new TcpWriteContext(c)); } @Override @@ -206,7 +209,7 @@ public class NioTransport extends TcpTransport { // loop through all profiles and start them up, special handling for default one for (ProfileSettings profileSettings : profileSettings) { - profileToChannelFactory.putIfAbsent(profileSettings.profileName, new ChannelFactory(profileSettings, tcpReadHandler)); + profileToChannelFactory.putIfAbsent(profileSettings.profileName, new ChannelFactory(profileSettings, contextSetter)); bindServer(profileSettings); } } @@ -243,7 +246,7 @@ public class NioTransport extends TcpTransport { private NioClient createClient() { Supplier selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); - ChannelFactory channelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), tcpReadHandler); + ChannelFactory channelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), contextSetter); return new NioClient(logger, openChannels, selectorSupplier, defaultConnectionProfile.getConnectTimeout(), channelFactory); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java index f2f92e94e50..199bab9a904 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java @@ -26,7 +26,6 @@ import org.elasticsearch.mocksocket.PrivilegedSocketAccess; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.SocketSelector; -import org.elasticsearch.transport.nio.TcpReadHandler; import java.io.Closeable; import java.io.IOException; @@ -39,15 +38,28 @@ import java.util.function.Consumer; public class ChannelFactory { - private final TcpReadHandler handler; + private final Consumer contextSetter; private final RawChannelFactory rawChannelFactory; - public ChannelFactory(TcpTransport.ProfileSettings profileSettings, TcpReadHandler handler) { - this(new RawChannelFactory(profileSettings), handler); + /** + * This will create a {@link ChannelFactory} using the profile settings and context setter passed to this + * constructor. The context setter must be a {@link Consumer} that calls + * {@link NioSocketChannel#setContexts(ReadContext, WriteContext)} with the appropriate read and write + * contexts. The read and write contexts handle the protocol specific encoding and decoding of messages. + * + * @param profileSettings the profile settings channels opened by this factory + * @param contextSetter a consumer that takes a channel and sets the read and write contexts + */ + public ChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer contextSetter) { + this(new RawChannelFactory(profileSettings.tcpNoDelay, + profileSettings.tcpKeepAlive, + profileSettings.reuseAddress, + Math.toIntExact(profileSettings.sendBufferSize.getBytes()), + Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), contextSetter); } - ChannelFactory(RawChannelFactory rawChannelFactory, TcpReadHandler handler) { - this.handler = handler; + ChannelFactory(RawChannelFactory rawChannelFactory, Consumer contextSetter) { + this.contextSetter = contextSetter; this.rawChannelFactory = rawChannelFactory; } @@ -55,7 +67,7 @@ public class ChannelFactory { Consumer closeListener) throws IOException { SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel, selector); - channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel)); + setContexts(channel); channel.getCloseFuture().addListener(ActionListener.wrap(closeListener::accept, (e) -> closeListener.accept(channel))); scheduleChannel(channel, selector); return channel; @@ -65,7 +77,7 @@ public class ChannelFactory { Consumer closeListener) throws IOException { SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel, selector); - channel.setContexts(new TcpReadContext(channel, handler), new TcpWriteContext(channel)); + setContexts(channel); channel.getCloseFuture().addListener(ActionListener.wrap(closeListener::accept, (e) -> closeListener.accept(channel))); scheduleChannel(channel, selector); return channel; @@ -97,6 +109,12 @@ public class ChannelFactory { } } + private void setContexts(NioSocketChannel channel) { + contextSetter.accept(channel); + assert channel.getReadContext() != null : "read context should have been set on channel"; + assert channel.getWriteContext() != null : "write context should have been set on channel"; + } + static class RawChannelFactory { private final boolean tcpNoDelay; @@ -105,12 +123,13 @@ public class ChannelFactory { private final int tcpSendBufferSize; private final int tcpReceiveBufferSize; - RawChannelFactory(TcpTransport.ProfileSettings profileSettings) { - tcpNoDelay = profileSettings.tcpNoDelay; - tcpKeepAlive = profileSettings.tcpKeepAlive; - tcpReusedAddress = profileSettings.reuseAddress; - tcpSendBufferSize = Math.toIntExact(profileSettings.sendBufferSize.getBytes()); - tcpReceiveBufferSize = Math.toIntExact(profileSettings.receiveBufferSize.getBytes()); + RawChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReusedAddress, int tcpSendBufferSize, + int tcpReceiveBufferSize) { + this.tcpNoDelay = tcpNoDelay; + this.tcpKeepAlive = tcpKeepAlive; + this.tcpReusedAddress = tcpReusedAddress; + this.tcpSendBufferSize = tcpSendBufferSize; + this.tcpReceiveBufferSize = tcpReceiveBufferSize; } SocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException { diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java index 8851c37f201..710f26bedcf 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java @@ -26,6 +26,8 @@ import org.elasticsearch.transport.nio.SocketSelector; import org.elasticsearch.transport.nio.TcpReadHandler; import org.junit.After; import org.junit.Before; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.io.IOException; import java.net.InetAddress; @@ -36,6 +38,7 @@ import java.util.function.Consumer; import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -55,12 +58,19 @@ public class ChannelFactoryTests extends ESTestCase { @SuppressWarnings("unchecked") public void setupFactory() throws IOException { rawChannelFactory = mock(ChannelFactory.RawChannelFactory.class); - channelFactory = new ChannelFactory(rawChannelFactory, mock(TcpReadHandler.class)); + Consumer contextSetter = mock(Consumer.class); + channelFactory = new ChannelFactory(rawChannelFactory, contextSetter); listener = mock(Consumer.class); socketSelector = mock(SocketSelector.class); acceptingSelector = mock(AcceptingSelector.class); rawChannel = SocketChannel.open(); rawServerChannel = ServerSocketChannel.open(); + + doAnswer(invocationOnMock -> { + NioSocketChannel channel = (NioSocketChannel) invocationOnMock.getArguments()[0]; + channel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + return null; + }).when(contextSetter).accept(any()); } @After