From 984ba82251ca21171bb274eabd29f82d43604c89 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Fri, 2 Aug 2019 17:31:31 -0400 Subject: [PATCH] Move nio channel initialization to event loop (#45155) Currently in the transport-nio work we connect and bind channels on the a thread before the channel is registered with a selector. Additionally, it is at this point that we set all the socket options. This commit moves these operations onto the event-loop after the channel has been registered with a selector. It attempts to set the socket options for a non-server channel at registration time. If that fails, it will attempt to set the options after the channel is connected. This should fix #41071. --- .../nio/BytesChannelContext.java | 6 +- .../org/elasticsearch/nio/ChannelFactory.java | 208 +++++++----------- .../java/org/elasticsearch/nio/Config.java | 94 ++++++++ .../nio/NioServerSocketChannel.java | 6 +- .../elasticsearch/nio/NioSocketChannel.java | 16 +- .../nio/ServerChannelContext.java | 57 ++++- .../nio/SocketChannelContext.java | 88 +++++++- .../nio/BytesChannelContextTests.java | 2 +- .../nio/ChannelFactoryTests.java | 31 +-- .../elasticsearch/nio/EventHandlerTests.java | 33 ++- .../nio/SocketChannelContextTests.java | 72 +++++- .../netty4/Netty4TcpServerChannel.java | 9 +- .../transport/netty4/Netty4Transport.java | 2 +- .../http/nio/NioHttpServerTransport.java | 23 +- .../transport/nio/NioTcpServerChannel.java | 10 +- .../transport/nio/NioTransport.java | 36 +-- .../elasticsearch/http/nio/NioHttpClient.java | 21 +- .../transport/nio/NioGroupFactoryTests.java | 10 +- .../transport/TcpServerChannel.java | 5 - .../transport/nio/MockNioTransport.java | 41 ++-- .../transport/nio/SSLChannelContext.java | 15 +- .../nio/SecurityNioHttpServerTransport.java | 17 +- .../transport/nio/SecurityNioTransport.java | 55 ++--- .../transport/nio/SSLChannelContextTests.java | 14 +- .../SecurityNioHttpServerTransportTests.java | 13 +- 25 files changed, 552 insertions(+), 332 deletions(-) create mode 100644 libs/nio/src/main/java/org/elasticsearch/nio/Config.java diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java index 211e609ba4c..5f257bd0265 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java @@ -24,9 +24,9 @@ import java.util.function.Consumer; public class BytesChannelContext extends SocketChannelContext { - public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - NioChannelHandler handler, InboundChannelBuffer channelBuffer) { - super(channel, selector, exceptionHandler, handler, channelBuffer); + public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig, + Consumer exceptionHandler, NioChannelHandler handler, InboundChannelBuffer channelBuffer) { + super(channel, selector, socketConfig, exceptionHandler, handler, channelBuffer); } @Override diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java index b886f9b68aa..0b613258925 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java @@ -19,58 +19,69 @@ package org.elasticsearch.nio; -import org.elasticsearch.common.CheckedRunnable; - import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; import java.net.InetSocketAddress; -import java.net.SocketException; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; +import java.nio.channels.spi.AbstractSelectableChannel; import java.util.function.Supplier; public abstract class ChannelFactory { + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean tcpReuseAddress; + private final int tcpSendBufferSize; + private final int tcpReceiveBufferSize; private final ChannelFactory.RawChannelFactory rawChannelFactory; /** - * This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor. - * - * @param rawChannelFactory a factory that will construct the raw socket channels + * This will create a {@link ChannelFactory}. */ - protected ChannelFactory(RawChannelFactory rawChannelFactory) { + protected ChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize, + int tcpReceiveBufferSize) { + this(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, new RawChannelFactory()); + } + + /** + * This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor. + */ + protected ChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize, + int tcpReceiveBufferSize, RawChannelFactory rawChannelFactory) { + this.tcpNoDelay = tcpNoDelay; + this.tcpKeepAlive = tcpKeepAlive; + this.tcpReuseAddress = tcpReuseAddress; + this.tcpSendBufferSize = tcpSendBufferSize; + this.tcpReceiveBufferSize = tcpReceiveBufferSize; this.rawChannelFactory = rawChannelFactory; } public Socket openNioChannel(InetSocketAddress remoteAddress, Supplier supplier) throws IOException { - SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); + SocketChannel rawChannel = rawChannelFactory.openNioChannel(); + setNonBlocking(rawChannel); NioSelector selector = supplier.get(); - Socket channel = internalCreateChannel(selector, rawChannel); + Socket channel = internalCreateChannel(selector, rawChannel, createSocketConfig(remoteAddress, false)); scheduleChannel(channel, selector); return channel; } - public Socket acceptNioChannel(ServerChannelContext serverContext, Supplier supplier) throws IOException { - SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverContext); - // Null is returned if there are no pending sockets to accept - if (rawChannel == null) { - return null; - } else { - NioSelector selector = supplier.get(); - Socket channel = internalCreateChannel(selector, rawChannel); - scheduleChannel(channel, selector); - return channel; - } + public Socket acceptNioChannel(SocketChannel rawChannel, Supplier supplier) throws IOException { + setNonBlocking(rawChannel); + NioSelector selector = supplier.get(); + InetSocketAddress remoteAddress = getRemoteAddress(rawChannel); + Socket channel = internalCreateChannel(selector, rawChannel, createSocketConfig(remoteAddress, true)); + scheduleChannel(channel, selector); + return channel; } - public ServerSocket openNioServerSocketChannel(InetSocketAddress address, Supplier supplier) throws IOException { - ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); + public ServerSocket openNioServerSocketChannel(InetSocketAddress localAddress, Supplier supplier) throws IOException { + ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(); + setNonBlocking(rawChannel); NioSelector selector = supplier.get(); - ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel); + Config.ServerSocket config = new Config.ServerSocket(tcpReuseAddress, localAddress); + ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel, config); scheduleServerChannel(serverChannel, selector); return serverChannel; } @@ -80,27 +91,38 @@ public abstract class ChannelFactory runnable) throws SocketException { - try { - runnable.run(); - } catch (SocketException e) { - if (MAC_OS_X == false) { - // ignore on Mac, see https://github.com/elastic/elasticsearch/issues/41071 - throw e; - } - } - } - - private void configureSocketChannel(SocketChannel channel) throws IOException { - channel.configureBlocking(false); - java.net.Socket socket = channel.socket(); - setSocketOption(() -> socket.setTcpNoDelay(tcpNoDelay)); - setSocketOption(() -> socket.setKeepAlive(tcpKeepAlive)); - setSocketOption(() -> socket.setReuseAddress(tcpReusedAddress)); - if (tcpSendBufferSize > 0) { - setSocketOption(() -> socket.setSendBufferSize(tcpSendBufferSize)); - } - if (tcpReceiveBufferSize > 0) { - setSocketOption(() -> socket.setSendBufferSize(tcpReceiveBufferSize)); - } - } - - public static SocketChannel accept(ServerSocketChannel serverSocketChannel) throws IOException { - try { - return AccessController.doPrivileged((PrivilegedExceptionAction) serverSocketChannel::accept); - } catch (PrivilegedActionException e) { - throw (IOException) e.getCause(); - } - } - - private static void connect(SocketChannel socketChannel, InetSocketAddress remoteAddress) throws IOException { - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> socketChannel.connect(remoteAddress)); - } catch (PrivilegedActionException e) { - throw (IOException) e.getCause(); - } + ServerSocketChannel openNioServerSocketChannel() throws IOException { + return ServerSocketChannel.open(); } } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/Config.java b/libs/nio/src/main/java/org/elasticsearch/nio/Config.java new file mode 100644 index 00000000000..934a6fe336a --- /dev/null +++ b/libs/nio/src/main/java/org/elasticsearch/nio/Config.java @@ -0,0 +1,94 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.net.InetSocketAddress; + +public abstract class Config { + + private final boolean tcpReuseAddress; + + public Config(boolean tcpReuseAddress) { + this.tcpReuseAddress = tcpReuseAddress; + } + + public boolean tcpReuseAddress() { + return tcpReuseAddress; + } + + public static class Socket extends Config { + + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final int tcpSendBufferSize; + private final int tcpReceiveBufferSize; + private final InetSocketAddress remoteAddress; + private final boolean isAccepted; + + public Socket(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize, int tcpReceiveBufferSize, + InetSocketAddress remoteAddress, boolean isAccepted) { + super(tcpReuseAddress); + this.tcpNoDelay = tcpNoDelay; + this.tcpKeepAlive = tcpKeepAlive; + this.tcpSendBufferSize = tcpSendBufferSize; + this.tcpReceiveBufferSize = tcpReceiveBufferSize; + this.remoteAddress = remoteAddress; + this.isAccepted = isAccepted; + } + + public boolean tcpNoDelay() { + return tcpNoDelay; + } + + public boolean tcpKeepAlive() { + return tcpKeepAlive; + } + + public int tcpSendBufferSize() { + return tcpSendBufferSize; + } + + public int tcpReceiveBufferSize() { + return tcpReceiveBufferSize; + } + + public boolean isAccepted() { + return isAccepted; + } + + public InetSocketAddress getRemoteAddress() { + return remoteAddress; + } + } + + public static class ServerSocket extends Config { + + private InetSocketAddress localAddress; + + public ServerSocket(boolean tcpReuseAddress, InetSocketAddress localAddress) { + super(tcpReuseAddress); + this.localAddress = localAddress; + } + + public InetSocketAddress getLocalAddress() { + return localAddress; + } + } +} diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java index a335e692588..d18c2ec56dd 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioServerSocketChannel.java @@ -22,6 +22,7 @@ package org.elasticsearch.nio; import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; public class NioServerSocketChannel extends NioChannel { @@ -32,7 +33,6 @@ public class NioServerSocketChannel extends NioChannel { public NioServerSocketChannel(ServerSocketChannel serverSocketChannel) { this.serverSocketChannel = serverSocketChannel; - attemptToSetLocalAddress(); } /** @@ -49,6 +49,10 @@ public class NioServerSocketChannel extends NioChannel { } } + public void addBindListener(BiConsumer listener) { + context.addBindListener(listener); + } + @Override public InetSocketAddress getLocalAddress() { attemptToSetLocalAddress(); diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java index c7d44990837..2a9c97610d9 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java @@ -19,8 +19,6 @@ package org.elasticsearch.nio; -import java.io.IOException; -import java.io.UncheckedIOException; import java.net.InetSocketAddress; import java.nio.channels.SocketChannel; import java.util.concurrent.atomic.AtomicBoolean; @@ -30,17 +28,12 @@ public class NioSocketChannel extends NioChannel { private final AtomicBoolean contextSet = new AtomicBoolean(false); private final SocketChannel socketChannel; - private final InetSocketAddress remoteAddress; + private volatile InetSocketAddress remoteAddress; private volatile InetSocketAddress localAddress; - private SocketChannelContext context; + private volatile SocketChannelContext context; public NioSocketChannel(SocketChannel socketChannel) { this.socketChannel = socketChannel; - try { - this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } } public void setContext(SocketChannelContext context) { @@ -70,6 +63,9 @@ public class NioSocketChannel extends NioChannel { } public InetSocketAddress getRemoteAddress() { + if (remoteAddress == null) { + remoteAddress = (InetSocketAddress) socketChannel.socket().getRemoteSocketAddress(); + } return remoteAddress; } @@ -81,7 +77,7 @@ public class NioSocketChannel extends NioChannel { public String toString() { return "NioSocketChannel{" + "localAddress=" + getLocalAddress() + - ", remoteAddress=" + remoteAddress + + ", remoteAddress=" + getRemoteAddress() + '}'; } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java index 9e1af3e9973..ec637a3b046 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/ServerChannelContext.java @@ -19,9 +19,18 @@ package org.elasticsearch.nio; +import org.elasticsearch.common.concurrent.CompletableContext; + import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; @@ -29,23 +38,49 @@ public class ServerChannelContext extends ChannelContext { private final NioServerSocketChannel channel; private final NioSelector selector; + private final Config.ServerSocket config; private final Consumer acceptor; private final AtomicBoolean isClosing = new AtomicBoolean(false); private final ChannelFactory channelFactory; + private final CompletableContext bindContext = new CompletableContext<>(); public ServerChannelContext(NioServerSocketChannel channel, ChannelFactory channelFactory, NioSelector selector, - Consumer acceptor, Consumer exceptionHandler) { + Config.ServerSocket config, Consumer acceptor, + Consumer exceptionHandler) { super(channel.getRawChannel(), exceptionHandler); this.channel = channel; this.channelFactory = channelFactory; this.selector = selector; + this.config = config; this.acceptor = acceptor; } public void acceptChannels(Supplier selectorSupplier) throws IOException { - NioSocketChannel acceptedChannel; - while ((acceptedChannel = channelFactory.acceptNioChannel(this, selectorSupplier)) != null) { - acceptor.accept(acceptedChannel); + SocketChannel acceptedChannel; + while ((acceptedChannel = accept(rawChannel)) != null) { + NioSocketChannel nioChannel = channelFactory.acceptNioChannel(acceptedChannel, selectorSupplier); + acceptor.accept(nioChannel); + } + } + + public void addBindListener(BiConsumer listener) { + bindContext.addListener(listener); + } + + @Override + protected void register() throws IOException { + super.register(); + + configureSocket(rawChannel.socket()); + + InetSocketAddress localAddress = config.getLocalAddress(); + try { + rawChannel.bind(localAddress); + bindContext.complete(null); + } catch (IOException e) { + IOException exception = new IOException("Failed to bind server socket channel {localAddress=" + localAddress + "}.", e); + bindContext.completeExceptionally(exception); + throw exception; } } @@ -66,4 +101,18 @@ public class ServerChannelContext extends ChannelContext { return channel; } + private void configureSocket(ServerSocket socket) throws IOException { + socket.setReuseAddress(config.tcpReuseAddress()); + } + + protected static SocketChannel accept(ServerSocketChannel serverSocketChannel) throws IOException { + try { + assert serverSocketChannel.isBlocking() == false; + SocketChannel channel = AccessController.doPrivileged((PrivilegedExceptionAction) serverSocketChannel::accept); + assert serverSocketChannel.isBlocking() == false; + return channel; + } catch (PrivilegedActionException e) { + throw (IOException) e.getCause(); + } + } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index bc93466b58a..e204cb47907 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -24,9 +24,14 @@ import org.elasticsearch.nio.utils.ByteBufferUtils; import org.elasticsearch.nio.utils.ExceptionsHelper; import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.LinkedList; import java.util.concurrent.atomic.AtomicBoolean; @@ -47,19 +52,23 @@ public abstract class SocketChannelContext extends ChannelContext protected final NioSocketChannel channel; protected final InboundChannelBuffer channelBuffer; protected final AtomicBoolean isClosing = new AtomicBoolean(false); - private final NioChannelHandler readWriteHandler; + private final NioChannelHandler channelHandler; private final NioSelector selector; + private final Config.Socket socketConfig; private final CompletableContext connectContext = new CompletableContext<>(); private final LinkedList pendingFlushes = new LinkedList<>(); private boolean closeNow; + private boolean socketOptionsSet; private Exception connectException; - protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) { + protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig, + Consumer exceptionHandler, NioChannelHandler channelHandler, + InboundChannelBuffer channelBuffer) { super(channel.getRawChannel(), exceptionHandler); this.selector = selector; this.channel = channel; - this.readWriteHandler = readWriteHandler; + this.socketConfig = socketConfig; + this.channelHandler = channelHandler; this.channelBuffer = channelBuffer; } @@ -73,6 +82,22 @@ public abstract class SocketChannelContext extends ChannelContext return channel; } + @Override + protected void register() throws IOException { + super.register(); + + configureSocket(rawChannel.socket(), false); + + if (socketConfig.isAccepted() == false) { + InetSocketAddress remoteAddress = socketConfig.getRemoteAddress(); + try { + connect(rawChannel, remoteAddress); + } catch (IOException e) { + throw new IOException("Failed to initiate socket channel connection {remoteAddress=" + remoteAddress + "}.", e); + } + } + } + public void addConnectListener(BiConsumer listener) { connectContext.addListener(listener); } @@ -117,6 +142,7 @@ public abstract class SocketChannelContext extends ChannelContext } if (isConnected) { connectContext.complete(null); + configureSocket(rawChannel.socket(), true); } return isConnected; } @@ -127,14 +153,14 @@ public abstract class SocketChannelContext extends ChannelContext return; } - WriteOperation writeOperation = readWriteHandler.createWriteOperation(this, message, listener); + WriteOperation writeOperation = channelHandler.createWriteOperation(this, message, listener); getSelector().queueWrite(writeOperation); } public void queueWriteOperation(WriteOperation writeOperation) { getSelector().assertOnSelectorThread(); - pendingFlushes.addAll(readWriteHandler.writeToBytes(writeOperation)); + pendingFlushes.addAll(channelHandler.writeToBytes(writeOperation)); } public abstract int read() throws IOException; @@ -157,7 +183,7 @@ public abstract class SocketChannelContext extends ChannelContext @Override protected void channelActive() throws IOException { - readWriteHandler.channelActive(); + channelHandler.channelActive(); } @Override @@ -174,14 +200,14 @@ public abstract class SocketChannelContext extends ChannelContext isClosing.set(true); // Poll for new flush operations to close - pendingFlushes.addAll(readWriteHandler.pollFlushOperations()); + pendingFlushes.addAll(channelHandler.pollFlushOperations()); FlushOperation flushOperation; while ((flushOperation = pendingFlushes.pollFirst()) != null) { selector.executeFailedListener(flushOperation.getListener(), new ClosedChannelException()); } try { - readWriteHandler.close(); + channelHandler.close(); } catch (IOException e) { closingExceptions.add(e); } @@ -196,12 +222,12 @@ public abstract class SocketChannelContext extends ChannelContext protected void handleReadBytes() throws IOException { int bytesConsumed = Integer.MAX_VALUE; while (isOpen() && bytesConsumed > 0 && channelBuffer.getIndex() > 0) { - bytesConsumed = readWriteHandler.consumeReads(channelBuffer); + bytesConsumed = channelHandler.consumeReads(channelBuffer); channelBuffer.release(bytesConsumed); } // Some protocols might produce messages to flush during a read operation. - pendingFlushes.addAll(readWriteHandler.pollFlushOperations()); + pendingFlushes.addAll(channelHandler.pollFlushOperations()); } public boolean readyForFlush() { @@ -217,7 +243,7 @@ public abstract class SocketChannelContext extends ChannelContext public abstract boolean selectorShouldClose(); protected boolean closeNow() { - return closeNow || readWriteHandler.closeNow(); + return closeNow || channelHandler.closeNow(); } protected void setCloseNow() { @@ -288,4 +314,42 @@ public abstract class SocketChannelContext extends ChannelContext } return totalBytesFlushed; } + + private void configureSocket(Socket socket, boolean isConnectComplete) throws IOException { + if (socketOptionsSet) { + return; + } + + try { + // Set reuse address first as it must be set before a bind call. Some implementations throw + // exceptions on other socket options if the channel is not connected. But setting reuse first, + // we ensure that it is properly set before any bind attempt. + socket.setReuseAddress(socketConfig.tcpReuseAddress()); + socket.setKeepAlive(socketConfig.tcpKeepAlive()); + socket.setTcpNoDelay(socketConfig.tcpNoDelay()); + int tcpSendBufferSize = socketConfig.tcpSendBufferSize(); + if (tcpSendBufferSize > 0) { + socket.setSendBufferSize(tcpSendBufferSize); + } + int tcpReceiveBufferSize = socketConfig.tcpReceiveBufferSize(); + if (tcpReceiveBufferSize > 0) { + socket.setReceiveBufferSize(tcpReceiveBufferSize); + } + socketOptionsSet = true; + } catch (IOException e) { + if (isConnectComplete) { + throw e; + } + // Ignore if not connect complete. Some implementations fail on setting socket options if the + // socket is not connected. We will try again after connection. + } + } + + private static void connect(SocketChannel socketChannel, InetSocketAddress remoteAddress) throws IOException { + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> socketChannel.connect(remoteAddress)); + } catch (PrivilegedActionException e) { + throw (IOException) e.getCause(); + } + } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java index c98e7dc8dfb..1bad1dd2b14 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -62,7 +62,7 @@ public class BytesChannelContextTests extends ESTestCase { channelBuffer = InboundChannelBuffer.allocatingInstance(); TestReadWriteHandler handler = new TestReadWriteHandler(readConsumer); when(channel.getRawChannel()).thenReturn(rawChannel); - context = new BytesChannelContext(channel, selector, mock(Consumer.class), handler, channelBuffer); + context = new BytesChannelContext(channel, selector, mock(Config.Socket.class), mock(Consumer.class), handler, channelBuffer); when(selector.isOnCurrentThread()).thenReturn(true); ByteBuffer buffer = ByteBuffer.allocate(1 << 14); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java index af4eabefd94..215b0a8042c 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java @@ -31,7 +31,6 @@ import java.nio.channels.SocketChannel; import java.util.function.Supplier; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.same; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -69,10 +68,7 @@ public class ChannelFactoryTests extends ESTestCase { } public void testAcceptChannel() throws IOException { - ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); - when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel); - - NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier); + NioSocketChannel channel = channelFactory.acceptNioChannel(rawChannel, socketSelectorSupplier); verify(socketSelector).scheduleForRegistration(channel); @@ -80,18 +76,16 @@ public class ChannelFactoryTests extends ESTestCase { } public void testAcceptedChannelRejected() throws IOException { - ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); - when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier)); + expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(rawChannel, socketSelectorSupplier)); assertFalse(rawChannel.isOpen()); } public void testOpenChannel() throws IOException { InetSocketAddress address = mock(InetSocketAddress.class); - when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); + when(rawChannelFactory.openNioChannel()).thenReturn(rawChannel); NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelectorSupplier); @@ -102,7 +96,7 @@ public class ChannelFactoryTests extends ESTestCase { public void testOpenedChannelRejected() throws IOException { InetSocketAddress address = mock(InetSocketAddress.class); - when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); + when(rawChannelFactory.openNioChannel()).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelectorSupplier)); @@ -112,7 +106,7 @@ public class ChannelFactoryTests extends ESTestCase { public void testOpenServerChannel() throws IOException { InetSocketAddress address = mock(InetSocketAddress.class); - when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); + when(rawChannelFactory.openNioServerSocketChannel()).thenReturn(rawServerChannel); NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier); @@ -123,7 +117,7 @@ public class ChannelFactoryTests extends ESTestCase { public void testOpenedServerChannelRejected() throws IOException { InetSocketAddress address = mock(InetSocketAddress.class); - when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); + when(rawChannelFactory.openNioServerSocketChannel()).thenReturn(rawServerChannel); doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any()); expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier)); @@ -134,19 +128,26 @@ public class ChannelFactoryTests extends ESTestCase { private static class TestChannelFactory extends ChannelFactory { TestChannelFactory(RawChannelFactory rawChannelFactory) { - super(rawChannelFactory); + super(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, rawChannelFactory); } @Override - public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) { NioSocketChannel nioSocketChannel = new NioSocketChannel(channel); nioSocketChannel.setContext(mock(SocketChannelContext.class)); return nioSocketChannel; } @Override - public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { + public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { return new NioServerSocketChannel(channel); } + + @Override + protected InetSocketAddress getRemoteAddress(SocketChannel rawChannel) throws IOException { + // Override this functionality to avoid having to connect the accepted channel + return mock(InetSocketAddress.class); + } } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java index 726d87317ff..bf6b215c6eb 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java @@ -23,7 +23,9 @@ import org.elasticsearch.test.ESTestCase; import org.junit.Before; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.ServerSocket; +import java.net.Socket; import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; @@ -33,7 +35,6 @@ import java.util.Collections; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; -import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -67,6 +68,7 @@ public class EventHandlerTests extends ESTestCase { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenReturn(true); NioSocketChannel channel = new NioSocketChannel(rawChannel); + when(rawChannel.socket()).thenReturn(mock(Socket.class)); context = new DoNotRegisterSocketContext(channel, selector, channelExceptionHandler, readWriteHandler); channel.setContext(context); handler.handleRegistration(context); @@ -104,22 +106,8 @@ public class EventHandlerTests extends ESTestCase { assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps()); } - public void testHandleAcceptCallsChannelFactory() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class)); - NioSocketChannel nullChannel = null; - when(channelFactory.acceptNioChannel(same(serverContext), same(selectorSupplier))).thenReturn(childChannel, nullChannel); - - handler.acceptChannel(serverContext); - - verify(channelFactory, times(2)).acceptNioChannel(same(serverContext), same(selectorSupplier)); - } - - public void testHandleAcceptCallsServerAcceptCallback() throws IOException { - NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class)); - SocketChannelContext childContext = mock(SocketChannelContext.class); - childChannel.setContext(childContext); + public void testHandleAcceptAccept() throws IOException { ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); - when(channelFactory.acceptNioChannel(same(serverContext), same(selectorSupplier))).thenReturn(childChannel); handler.acceptChannel(serverChannelContext); @@ -254,7 +242,7 @@ public class EventHandlerTests extends ESTestCase { DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, NioChannelHandler handler) { - super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance()); + super(channel, selector, getSocketConfig(), exceptionHandler, handler, InboundChannelBuffer.allocatingInstance()); } @Override @@ -270,7 +258,7 @@ public class EventHandlerTests extends ESTestCase { @SuppressWarnings("unchecked") DoNotRegisterServerContext(NioServerSocketChannel channel, NioSelector selector, Consumer acceptor) { - super(channel, channelFactory, selector, acceptor, mock(Consumer.class)); + super(channel, channelFactory, selector, getServerSocketConfig(), acceptor, mock(Consumer.class)); } @Override @@ -280,4 +268,13 @@ public class EventHandlerTests extends ESTestCase { selectionKey.attach(this); } } + + private static Config.ServerSocket getServerSocketConfig() { + return new Config.ServerSocket(randomBoolean(), mock(InetSocketAddress.class)); + } + + private static Config.Socket getSocketConfig() { + return new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class), + randomBoolean()); + } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index 210a27aa109..9dfdef4164b 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -19,12 +19,16 @@ package org.elasticsearch.nio; +import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; @@ -40,11 +44,13 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.isNull; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +@SuppressForbidden(reason = "allow call to socket connect") public class SocketChannelContextTests extends ESTestCase { private SocketChannel rawChannel; @@ -55,6 +61,7 @@ public class SocketChannelContextTests extends ESTestCase { private NioSelector selector; private NioChannelHandler handler; private ByteBuffer ioBuffer = ByteBuffer.allocate(1024); + private Socket rawSocket; @SuppressWarnings("unchecked") @Before @@ -76,6 +83,8 @@ public class SocketChannelContextTests extends ESTestCase { ioBuffer.clear(); return ioBuffer; }); + rawSocket = mock(Socket.class); + when(rawChannel.socket()).thenReturn(rawSocket); } public void testIOExceptionSetIfEncountered() throws IOException { @@ -101,6 +110,31 @@ public class SocketChannelContextTests extends ESTestCase { assertTrue(context.closeNow()); } + public void testRegisterInitiatesConnect() throws IOException { + InetSocketAddress address = mock(InetSocketAddress.class); + boolean isAccepted = randomBoolean(); + Config.Socket config; + boolean tcpNoDelay = randomBoolean(); + boolean tcpKeepAlive = randomBoolean(); + boolean tcpReuseAddress = randomBoolean(); + int tcpSendBufferSize = randomIntBetween(1000, 2000); + int tcpReceiveBufferSize = randomIntBetween(1000, 2000); + config = new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, address, isAccepted); + InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); + TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer, config); + context.register(); + if (isAccepted) { + verify(rawChannel, times(0)).connect(any(InetSocketAddress.class)); + } else { + verify(rawChannel).connect(same(address)); + } + verify(rawSocket).setTcpNoDelay(tcpNoDelay); + verify(rawSocket).setKeepAlive(tcpKeepAlive); + verify(rawSocket).setReuseAddress(tcpReuseAddress); + verify(rawSocket).setSendBufferSize(tcpSendBufferSize); + verify(rawSocket).setReceiveBufferSize(tcpReceiveBufferSize); + } + public void testConnectSucceeds() throws IOException { AtomicBoolean listenerCalled = new AtomicBoolean(false); when(rawChannel.finishConnect()).thenReturn(false, true); @@ -142,6 +176,29 @@ public class SocketChannelContextTests extends ESTestCase { assertSame(ioException, exception.get()); } + public void testConnectCanSetSocketOptions() throws IOException { + InetSocketAddress address = mock(InetSocketAddress.class); + Config.Socket config; + boolean tcpNoDelay = randomBoolean(); + boolean tcpKeepAlive = randomBoolean(); + boolean tcpReuseAddress = randomBoolean(); + int tcpSendBufferSize = randomIntBetween(1000, 2000); + int tcpReceiveBufferSize = randomIntBetween(1000, 2000); + config = new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, address, false); + InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); + TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer, config); + doThrow(new SocketException()).doNothing().when(rawSocket).setReuseAddress(tcpReuseAddress); + context.register(); + when(rawChannel.finishConnect()).thenReturn(true); + context.connect(); + + verify(rawSocket, times(2)).setReuseAddress(tcpReuseAddress); + verify(rawSocket).setKeepAlive(tcpKeepAlive); + verify(rawSocket).setTcpNoDelay(tcpNoDelay); + verify(rawSocket).setSendBufferSize(tcpSendBufferSize); + verify(rawSocket).setReceiveBufferSize(tcpReceiveBufferSize); + } + public void testChannelActiveCallsHandler() throws IOException { context.channelActive(); verify(handler).channelActive(); @@ -262,7 +319,8 @@ public class SocketChannelContextTests extends ESTestCase { when(channel.getRawChannel()).thenReturn(realChannel); when(channel.isOpen()).thenReturn(true); InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); - BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, handler, buffer); + BytesChannelContext context = new BytesChannelContext(channel, selector, mock(Config.Socket.class), exceptionHandler, handler, + buffer); context.closeFromSelector(); verify(handler).close(); } @@ -379,11 +437,21 @@ public class SocketChannelContextTests extends ESTestCase { assertEquals(1, flushOperation.getBuffersToWrite()[0].position()); } + private static Config.Socket getSocketConfig() { + return new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class), + randomBoolean()); + } + private static class TestSocketChannelContext extends SocketChannelContext { private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) { - super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, getSocketConfig()); + } + + private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, + NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer, Config.Socket config) { + super(channel, selector, config, exceptionHandler, readWriteHandler, channelBuffer); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpServerChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpServerChannel.java index 830b0a8c203..bc293750eb4 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpServerChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4TcpServerChannel.java @@ -29,20 +29,13 @@ import java.net.InetSocketAddress; public class Netty4TcpServerChannel implements TcpServerChannel { private final Channel channel; - private final String profile; private final CompletableContext closeContext = new CompletableContext<>(); - Netty4TcpServerChannel(Channel channel, String profile) { + Netty4TcpServerChannel(Channel channel) { this.channel = channel; - this.profile = profile; Netty4TcpChannel.addListener(this.channel.closeFuture(), closeContext); } - @Override - public String getProfile() { - return profile; - } - @Override public InetSocketAddress getLocalAddress() { return (InetSocketAddress) channel.localAddress(); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index db6ebb28749..5f29c51a1ce 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -239,7 +239,7 @@ public class Netty4Transport extends TcpTransport { @Override protected Netty4TcpServerChannel bind(String name, InetSocketAddress address) { Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel(); - Netty4TcpServerChannel esChannel = new Netty4TcpServerChannel(channel, name); + Netty4TcpServerChannel esChannel = new Netty4TcpServerChannel(channel); channel.attr(SERVER_CHANNEL_KEY).set(esChannel); return esChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 99a0d87e19c..5874030c3ee 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -22,6 +22,8 @@ package org.elasticsearch.http.nio; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -33,6 +35,7 @@ import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpServerChannel; import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; @@ -130,7 +133,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { @Override protected HttpServerChannel bind(InetSocketAddress socketAddress) throws IOException { - return nioGroup.bindServerChannel(socketAddress, channelFactory); + NioHttpServerChannel httpServerChannel = nioGroup.bindServerChannel(socketAddress, channelFactory); + PlainActionFuture future = PlainActionFuture.newFuture(); + httpServerChannel.addBindListener(ActionListener.toBiConsumer(future)); + future.actionGet(); + return httpServerChannel; } protected ChannelFactory channelFactory() { @@ -144,27 +151,29 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { private class HttpChannelFactory extends ChannelFactory { private HttpChannelFactory() { - super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); + super(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize); } @Override - public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) { NioHttpChannel httpChannel = new NioHttpChannel(channel); - HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, + HttpReadWriteHandler handler = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis); Consumer exceptionHandler = (e) -> onException(httpChannel, e); - SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline, + SocketChannelContext context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler, new InboundChannelBuffer(pageAllocator)); httpChannel.setContext(context); return httpChannel; } @Override - public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { + public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel); Consumer exceptionHandler = (e) -> onServerException(httpServerChannel, e); Consumer acceptor = NioHttpServerTransport.this::acceptChannel; - ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, acceptor, exceptionHandler); + ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, socketConfig, acceptor, + exceptionHandler); httpServerChannel.setContext(context); return httpServerChannel; } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpServerChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpServerChannel.java index 0d4b00f14b4..dccc581acea 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpServerChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTcpServerChannel.java @@ -31,22 +31,14 @@ import java.nio.channels.ServerSocketChannel; */ public class NioTcpServerChannel extends NioServerSocketChannel implements TcpServerChannel { - private final String profile; - - public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) { + public NioTcpServerChannel(ServerSocketChannel socketChannel) { super(socketChannel); - this.profile = profile; } public void close() { getContext().closeChannel(); } - @Override - public String getProfile() { - return profile; - } - @Override public void addCloseListener(ActionListener listener) { addCloseListener(ActionListener.toBiConsumer(listener)); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index a39098a3d59..de7fbdd2649 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -23,6 +23,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -31,6 +33,7 @@ import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; @@ -38,6 +41,7 @@ import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.TransportSettings; import java.io.IOException; import java.net.InetSocketAddress; @@ -70,7 +74,11 @@ public class NioTransport extends TcpTransport { @Override protected NioTcpServerChannel bind(String name, InetSocketAddress address) throws IOException { TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name); - return nioGroup.bindServerChannel(address, channelFactory); + NioTcpServerChannel serverChannel = nioGroup.bindServerChannel(address, channelFactory); + PlainActionFuture future = PlainActionFuture.newFuture(); + serverChannel.addBindListener(ActionListener.toBiConsumer(future)); + future.actionGet(); + return serverChannel; } @Override @@ -85,7 +93,7 @@ public class NioTransport extends TcpTransport { try { nioGroup = groupFactory.getTransportGroup(); - ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); + ProfileSettings clientProfileSettings = new ProfileSettings(settings, TransportSettings.DEFAULT_PROFILE); clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings); if (NetworkService.NETWORK_SERVER.get(settings)) { @@ -133,8 +141,9 @@ public class NioTransport extends TcpTransport { protected abstract class TcpChannelFactory extends ChannelFactory { - protected TcpChannelFactory(RawChannelFactory rawChannelFactory) { - super(rawChannelFactory); + protected TcpChannelFactory(ProfileSettings profileSettings) { + super(profileSettings.tcpNoDelay, profileSettings.tcpKeepAlive, profileSettings.reuseAddress, + Math.toIntExact(profileSettings.sendBufferSize.getBytes()), Math.toIntExact(profileSettings.receiveBufferSize.getBytes())); } } @@ -144,32 +153,29 @@ public class NioTransport extends TcpTransport { private final String profileName; private TcpChannelFactoryImpl(ProfileSettings profileSettings, boolean isClient) { - super(new RawChannelFactory(profileSettings.tcpNoDelay, - profileSettings.tcpKeepAlive, - profileSettings.reuseAddress, - Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); + super(profileSettings); this.isClient = isClient; this.profileName = profileSettings.profileName; } @Override - public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) { + public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this); + TcpReadWriteHandler handler = new TcpReadWriteHandler(nioChannel, NioTransport.this); Consumer exceptionHandler = (e) -> onException(nioChannel, e); - BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, socketConfig, exceptionHandler, handler, new InboundChannelBuffer(pageAllocator)); nioChannel.setContext(context); return nioChannel; } @Override - public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { - NioTcpServerChannel nioChannel = new NioTcpServerChannel(profileName, channel); + public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { + NioTcpServerChannel nioChannel = new NioTcpServerChannel(channel); Consumer exceptionHandler = (e) -> onServerException(nioChannel, e); Consumer acceptor = NioTransport.this::acceptChannel; - ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler); nioChannel.setContext(context); return nioChannel; } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java index c430f2b57d3..da3d0d83e70 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java @@ -31,8 +31,8 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; -import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.network.NetworkService; @@ -42,14 +42,15 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.EventHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSelectorGroup; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; +import org.elasticsearch.nio.NioSelectorGroup; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.tasks.Task; @@ -149,7 +150,8 @@ class NioHttpClient implements Closeable { connectFuture.actionGet(); for (HttpRequest request : requests) { - nioSocketChannel.getContext().sendMessage(request, (v, e) -> {}); + nioSocketChannel.getContext().sendMessage(request, (v, e) -> { + }); } if (latch.await(30L, TimeUnit.SECONDS) == false) { fail("Failed to get all expected responses."); @@ -177,17 +179,17 @@ class NioHttpClient implements Closeable { private final Collection content; private ClientChannelFactory(CountDownLatch latch, Collection content) { - super(new RawChannelFactory(NetworkService.TCP_NO_DELAY.get(Settings.EMPTY), + super(NetworkService.TCP_NO_DELAY.get(Settings.EMPTY), NetworkService.TCP_KEEP_ALIVE.get(Settings.EMPTY), NetworkService.TCP_REUSE_ADDRESS.get(Settings.EMPTY), Math.toIntExact(NetworkService.TCP_SEND_BUFFER_SIZE.get(Settings.EMPTY).getBytes()), - Math.toIntExact(NetworkService.TCP_RECEIVE_BUFFER_SIZE.get(Settings.EMPTY).getBytes()))); + Math.toIntExact(NetworkService.TCP_RECEIVE_BUFFER_SIZE.get(Settings.EMPTY).getBytes())); this.latch = latch; this.content = content; } @Override - public NioSocketChannel createChannel(NioSelector selector, java.nio.channels.SocketChannel channel) throws IOException { + public NioSocketChannel createChannel(NioSelector selector, java.nio.channels.SocketChannel channel, Config.Socket socketConfig) { NioSocketChannel nioSocketChannel = new NioSocketChannel(channel); HttpClientHandler handler = new HttpClientHandler(nioSocketChannel, latch, content); Consumer exceptionHandler = (e) -> { @@ -195,14 +197,15 @@ class NioHttpClient implements Closeable { onException(e); nioSocketChannel.close(); }; - SocketChannelContext context = new BytesChannelContext(nioSocketChannel, selector, exceptionHandler, handler, + SocketChannelContext context = new BytesChannelContext(nioSocketChannel, selector, socketConfig, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance()); nioSocketChannel.setContext(context); return nioSocketChannel; } @Override - public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { + public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { throw new UnsupportedOperationException("Cannot create server channel"); } } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioGroupFactoryTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioGroupFactoryTests.java index 356eb39b734..ca13e28d512 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioGroupFactoryTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioGroupFactoryTests.java @@ -21,6 +21,7 @@ package org.elasticsearch.transport.nio; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; @@ -57,20 +58,21 @@ public class NioGroupFactoryTests extends ESTestCase { private static class BindingFactory extends ChannelFactory { private BindingFactory() { - super(new ChannelFactory.RawChannelFactory(false, false, false, -1, -1)); + super(false, false, false, -1, -1); } @Override - public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException { throw new IOException("boom"); } @Override - public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { + public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { NioServerSocketChannel nioChannel = new NioServerSocketChannel(channel); Consumer exceptionHandler = (e) -> {}; Consumer acceptor = (c) -> {}; - ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler); nioChannel.setContext(context); return nioChannel; } diff --git a/server/src/main/java/org/elasticsearch/transport/TcpServerChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpServerChannel.java index 408ec1af20b..c8edd4c2c9f 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpServerChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpServerChannel.java @@ -31,11 +31,6 @@ import java.net.InetSocketAddress; */ public interface TcpServerChannel extends CloseableChannel { - /** - * This returns the profile for this channel. - */ - String getProfile(); - /** * Returns the local address for this channel. * diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index 00f3ff7bb7c..6f1589755cf 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -25,6 +25,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -37,6 +38,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesWriteHandler; import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelectorGroup; @@ -91,7 +93,11 @@ public class MockNioTransport extends TcpTransport { @Override protected MockServerChannel bind(String name, InetSocketAddress address) throws IOException { MockTcpChannelFactory channelFactory = this.profileToChannelFactory.get(name); - return nioGroup.bindServerChannel(address, channelFactory); + MockServerChannel serverChannel = nioGroup.bindServerChannel(address, channelFactory); + PlainActionFuture future = PlainActionFuture.newFuture(); + serverChannel.addBindListener(ActionListener.toBiConsumer(future)); + future.actionGet(); + return serverChannel; } @Override @@ -190,17 +196,17 @@ public class MockNioTransport extends TcpTransport { private final String profileName; private MockTcpChannelFactory(boolean isClient, ProfileSettings profileSettings, String profileName) { - super(new RawChannelFactory(profileSettings.tcpNoDelay, + super(profileSettings.tcpNoDelay, profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); + Math.toIntExact(profileSettings.receiveBufferSize.getBytes())); this.isClient = isClient; this.profileName = profileName; } @Override - public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) { MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel); IntFunction pageSupplier = (length) -> { if (length > PageCacheRecycler.BYTE_PAGE_SIZE) { @@ -211,7 +217,7 @@ public class MockNioTransport extends TcpTransport { } }; MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this); - BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), + BytesChannelContext context = new BytesChannelContext(nioChannel, selector, socketConfig, (e) -> exceptionCaught(nioChannel, e), readWriteHandler, new InboundChannelBuffer(pageSupplier)); nioChannel.setContext(context); nioChannel.addConnectListener((v, e) -> { @@ -229,18 +235,19 @@ public class MockNioTransport extends TcpTransport { } @Override - public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { - MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel); + public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, Config.ServerSocket socketConfig) { + MockServerChannel nioServerChannel = new MockServerChannel(channel); Consumer exceptionHandler = (e) -> logger.error(() -> new ParameterizedMessage("exception from server channel caught on transport layer [{}]", channel), e); - ServerChannelContext context = new ServerChannelContext(nioServerChannel, null, selector, null, - exceptionHandler) { + ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, socketConfig, + MockNioTransport.this::acceptChannel, exceptionHandler) { @Override public void acceptChannels(Supplier selectorSupplier) throws IOException { int acceptCount = 0; - NioSocketChannel acceptedChannel; - while ((acceptedChannel = MockTcpChannelFactory.this.acceptNioChannel(this, selectorSupplier)) != null) { - acceptChannel(acceptedChannel); + SocketChannel acceptedChannel; + while ((acceptedChannel = accept(rawChannel)) != null) { + NioSocketChannel nioChannel = MockTcpChannelFactory.this.acceptNioChannel(acceptedChannel, selectorSupplier); + acceptChannel(nioChannel); ++acceptCount; if (acceptCount % 100 == 0) { logger.warn("Accepted [{}] connections in a single select loop iteration on [{}]", acceptCount, channel); @@ -272,11 +279,8 @@ public class MockNioTransport extends TcpTransport { private static class MockServerChannel extends NioServerSocketChannel implements TcpServerChannel { - private final String profile; - - MockServerChannel(String profile, ServerSocketChannel channel) { + MockServerChannel(ServerSocketChannel channel) { super(channel); - this.profile = profile; } @Override @@ -284,11 +288,6 @@ public class MockNioTransport extends TcpTransport { getContext().closeChannel(); } - @Override - public String getProfile() { - return profile; - } - @Override public void addCloseListener(ActionListener listener) { addCloseListener(ActionListener.toBiConsumer(listener)); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 6a1684dd024..1c4115f0c01 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -13,6 +13,7 @@ import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.SocketChannelContext; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.WriteOperation; import javax.net.ssl.SSLEngine; @@ -41,15 +42,17 @@ public final class SSLChannelContext extends SocketChannelContext { private final LinkedList encryptedFlushes = new LinkedList<>(); private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; - SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - NioChannelHandler readWriteHandler, InboundChannelBuffer applicationBuffer) { - this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(), + SSLChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig, + Consumer exceptionHandler, SSLDriver sslDriver, NioChannelHandler readWriteHandler, + InboundChannelBuffer applicationBuffer) { + this(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(), applicationBuffer); } - SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - NioChannelHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) { - super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); + SSLChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig, + Consumer exceptionHandler, SSLDriver sslDriver, NioChannelHandler readWriteHandler, + InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) { + super(channel, selector, socketConfig, exceptionHandler, readWriteHandler, channelBuffer); this.sslDriver = sslDriver; this.networkReadBuffer = networkReadBuffer; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java index bf476e5b734..458a44eeaf4 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java @@ -18,6 +18,7 @@ import org.elasticsearch.http.nio.NioHttpServerChannel; import org.elasticsearch.http.nio.NioHttpServerTransport; import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; @@ -82,17 +83,17 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { class SecurityHttpChannelFactory extends ChannelFactory { private SecurityHttpChannelFactory() { - super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); + super(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize); } @Override - public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos); final NioChannelHandler handler; if (ipFilter != null) { - handler = new NioIPFilter(httpHandler, httpChannel.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME); + handler = new NioIPFilter(httpHandler, socketConfig.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME); } else { handler = httpHandler; } @@ -113,10 +114,10 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { } SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false); InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); - context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer, + context = new SSLChannelContext(httpChannel, selector, socketConfig, exceptionHandler, sslDriver, handler, networkBuffer, applicationBuffer); } else { - context = new BytesChannelContext(httpChannel, selector, exceptionHandler, handler, networkBuffer); + context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler, networkBuffer); } httpChannel.setContext(context); @@ -124,11 +125,13 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { } @Override - public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { + public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel); Consumer exceptionHandler = (e) -> onServerException(httpServerChannel, e); Consumer acceptor = SecurityNioHttpServerTransport.this::acceptChannel; - ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, acceptor, exceptionHandler); + ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, socketConfig, acceptor, + exceptionHandler); httpServerChannel.setContext(context); return httpServerChannel; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index 8d22d15612e..d546b88a8ce 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.nio.BytesChannelContext; -import org.elasticsearch.nio.ChannelFactory; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; @@ -111,9 +111,6 @@ public class SecurityNioTransport extends NioTransport { @Override protected Function clientChannelFactoryFunction(ProfileSettings profileSettings) { return (node) -> { - final ChannelFactory.RawChannelFactory rawChannelFactory = new ChannelFactory.RawChannelFactory(profileSettings.tcpNoDelay, - profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes())); SNIHostName serverName; String configuredServerName = node.getAttributes().get("server_name"); if (configuredServerName != null) { @@ -125,7 +122,7 @@ public class SecurityNioTransport extends NioTransport { } else { serverName = null; } - return new SecurityClientTcpChannelFactory(rawChannelFactory, serverName); + return new SecurityClientTcpChannelFactory(profileSettings, serverName); }; } @@ -135,26 +132,18 @@ public class SecurityNioTransport extends NioTransport { private final boolean isClient; private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) { - this(new RawChannelFactory(profileSettings.tcpNoDelay, - profileSettings.tcpKeepAlive, - profileSettings.reuseAddress, - Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), profileSettings.profileName, isClient); - } - - private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String profileName, boolean isClient) { - super(rawChannelFactory); - this.profileName = profileName; + super(profileSettings); + this.profileName = profileSettings.profileName; this.isClient = isClient; } @Override - public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); final NioChannelHandler handler; if (ipFilter != null) { - handler = new NioIPFilter(readWriteHandler, nioChannel.getRemoteAddress(), ipFilter, profileName); + handler = new NioIPFilter(readWriteHandler, socketConfig.getRemoteAddress(), ipFilter, profileName); } else { handler = readWriteHandler; } @@ -163,12 +152,12 @@ public class SecurityNioTransport extends NioTransport { SocketChannelContext context; if (sslEnabled) { - SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient); + SSLDriver sslDriver = new SSLDriver(createSSLEngine(socketConfig), pageAllocator, isClient); InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); - context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer, + context = new SSLChannelContext(nioChannel, selector, socketConfig, exceptionHandler, sslDriver, handler, networkBuffer, applicationBuffer); } else { - context = new BytesChannelContext(nioChannel, selector, exceptionHandler, handler, networkBuffer); + context = new BytesChannelContext(nioChannel, selector, socketConfig, exceptionHandler, handler, networkBuffer); } nioChannel.setContext(context); @@ -176,24 +165,25 @@ public class SecurityNioTransport extends NioTransport { } @Override - public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { - NioTcpServerChannel nioChannel = new NioTcpServerChannel(profileName, channel); + public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { + NioTcpServerChannel nioChannel = new NioTcpServerChannel(channel); Consumer exceptionHandler = (e) -> onServerException(nioChannel, e); Consumer acceptor = SecurityNioTransport.this::acceptChannel; - ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler); + ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler); nioChannel.setContext(context); return nioChannel; } - protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException { + protected SSLEngine createSSLEngine(Config.Socket socketConfig) throws IOException { SSLEngine sslEngine; SSLConfiguration defaultConfig = profileConfiguration.get(TransportSettings.DEFAULT_PROFILE); SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig); boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled(); - if (hostnameVerificationEnabled) { - InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress(); + if (hostnameVerificationEnabled && socketConfig.isAccepted() == false) { + InetSocketAddress remoteAddress = socketConfig.getRemoteAddress(); // we create the socket based on the name given. don't reverse DNS - sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort()); + sslEngine = sslService.createSSLEngine(sslConfig, remoteAddress.getHostString(), remoteAddress.getPort()); } else { sslEngine = sslService.createSSLEngine(sslConfig, null, -1); } @@ -205,19 +195,20 @@ public class SecurityNioTransport extends NioTransport { private final SNIHostName serverName; - private SecurityClientTcpChannelFactory(RawChannelFactory rawChannelFactory, SNIHostName serverName) { - super(rawChannelFactory, TransportSettings.DEFAULT_PROFILE, true); + private SecurityClientTcpChannelFactory(ProfileSettings profileSettings, SNIHostName serverName) { + super(profileSettings, true); this.serverName = serverName; } @Override - public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { + public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, + Config.ServerSocket socketConfig) { throw new AssertionError("Cannot create TcpServerChannel with client factory"); } @Override - protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException { - SSLEngine sslEngine = super.createSSLEngine(channel); + protected SSLEngine createSSLEngine(Config.Socket socketConfig) throws IOException { + SSLEngine sslEngine = super.createSSLEngine(socketConfig); if (serverName != null) { SSLParameters sslParameters = sslEngine.getSSLParameters(); sslParameters.setServerNames(Collections.singletonList(serverName)); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 8e0a5ad23af..ca0d07e8c19 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.Page; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.TaskScheduler; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.test.ESTestCase; @@ -23,6 +24,7 @@ import org.mockito.stubbing.Answer; import javax.net.ssl.SSLException; import java.io.IOException; +import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; @@ -55,6 +57,7 @@ public class SSLChannelContextTests extends ESTestCase { private Consumer exceptionHandler; private SSLDriver sslDriver; private int messageLength; + private Config.Socket socketConfig; @Before @SuppressWarnings("unchecked") @@ -73,7 +76,8 @@ public class SSLChannelContextTests extends ESTestCase { outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {})); when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + socketConfig = new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class), false); + context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); context.setSelectionKey(mock(SelectionKey.class)); when(selector.isOnCurrentThread()).thenReturn(true); @@ -180,7 +184,7 @@ public class SSLChannelContextTests extends ESTestCase { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); when(channel.isOpen()).thenReturn(true); context.closeFromSelector(); @@ -332,8 +336,10 @@ public class SSLChannelContextTests extends ESTestCase { try (SocketChannel realChannel = SocketChannel.open()) { when(channel.getRawChannel()).thenReturn(realChannel); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + + context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); context.setSelectionKey(mock(SelectionKey.class)); + context.closeChannel(); ArgumentCaptor captor = ArgumentCaptor.forClass(WriteOperation.class); verify(selector).queueWrite(captor.capture()); @@ -373,7 +379,7 @@ public class SSLChannelContextTests extends ESTestCase { when(selector.rawSelector()).thenReturn(realSelector); when(channel.getRawChannel()).thenReturn(realSocket); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); - context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); + context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); context.channelActive(); verify(sslDriver).init(); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransportTests.java index 76048590cea..7acef50ed4e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransportTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.http.nio.NioHttpChannel; +import org.elasticsearch.nio.Config; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -77,7 +78,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase { SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SocketChannel socketChannel = mock(SocketChannel.class); when(socketChannel.getRemoteAddress()).thenReturn(address); - NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); + NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class)); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); assertThat(engine.getNeedClientAuth(), is(false)); @@ -99,7 +100,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase { SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SocketChannel socketChannel = mock(SocketChannel.class); when(socketChannel.getRemoteAddress()).thenReturn(address); - NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); + NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class)); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); assertThat(engine.getNeedClientAuth(), is(false)); assertThat(engine.getWantClientAuth(), is(true)); @@ -120,7 +121,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase { SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SocketChannel socketChannel = mock(SocketChannel.class); when(socketChannel.getRemoteAddress()).thenReturn(address); - NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); + NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class)); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); assertThat(engine.getNeedClientAuth(), is(true)); assertThat(engine.getWantClientAuth(), is(false)); @@ -141,7 +142,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase { SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SocketChannel socketChannel = mock(SocketChannel.class); when(socketChannel.getRemoteAddress()).thenReturn(address); - NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); + NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class)); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); assertThat(engine.getNeedClientAuth(), is(false)); assertThat(engine.getWantClientAuth(), is(false)); @@ -159,7 +160,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase { SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SocketChannel socketChannel = mock(SocketChannel.class); when(socketChannel.getRemoteAddress()).thenReturn(address); - NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); + NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class)); SSLEngine defaultEngine = SSLEngineUtils.getSSLEngine(channel); settings = Settings.builder() @@ -173,7 +174,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase { new NetworkService(Collections.emptyList()), mock(BigArrays.class), mock(PageCacheRecycler.class), mock(ThreadPool.class), xContentRegistry(), new NullDispatcher(), mock(IPFilter.class), sslService, nioGroupFactory); factory = transport.channelFactory(); - channel = factory.createChannel(mock(NioSelector.class), socketChannel); + channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class)); SSLEngine customEngine = SSLEngineUtils.getSSLEngine(channel); assertThat(customEngine.getEnabledProtocols(), arrayContaining("TLSv1.2")); assertThat(customEngine.getEnabledProtocols(), not(equalTo(defaultEngine.getEnabledProtocols())));