Move nio channel initialization to event loop (#45155)

Currently in the transport-nio work we connect and bind channels on the
a thread before the channel is registered with a selector. Additionally,
it is at this point that we set all the socket options. This commit
moves these operations onto the event-loop after the channel has been
registered with a selector. It attempts to set the socket options for a
non-server channel at registration time. If that fails, it will attempt
to set the options after the channel is connected. This should fix
#41071.
This commit is contained in:
Tim Brooks 2019-08-02 17:31:31 -04:00 committed by GitHub
parent ffbe047c32
commit 984ba82251
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 552 additions and 332 deletions

View File

@ -24,9 +24,9 @@ import java.util.function.Consumer;
public class BytesChannelContext extends SocketChannelContext {
public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler handler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, handler, channelBuffer);
public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
Consumer<Exception> exceptionHandler, NioChannelHandler handler, InboundChannelBuffer channelBuffer) {
super(channel, selector, socketConfig, exceptionHandler, handler, channelBuffer);
}
@Override

View File

@ -19,58 +19,69 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.CheckedRunnable;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.nio.channels.spi.AbstractSelectableChannel;
import java.util.function.Supplier;
public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel, Socket extends NioSocketChannel> {
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean tcpReuseAddress;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
private final ChannelFactory.RawChannelFactory rawChannelFactory;
/**
* This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor.
*
* @param rawChannelFactory a factory that will construct the raw socket channels
* This will create a {@link ChannelFactory}.
*/
protected ChannelFactory(RawChannelFactory rawChannelFactory) {
protected ChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize) {
this(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, new RawChannelFactory());
}
/**
* This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor.
*/
protected ChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize, RawChannelFactory rawChannelFactory) {
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpReuseAddress = tcpReuseAddress;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
this.rawChannelFactory = rawChannelFactory;
}
public Socket openNioChannel(InetSocketAddress remoteAddress, Supplier<NioSelector> supplier) throws IOException {
SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress);
SocketChannel rawChannel = rawChannelFactory.openNioChannel();
setNonBlocking(rawChannel);
NioSelector selector = supplier.get();
Socket channel = internalCreateChannel(selector, rawChannel);
Socket channel = internalCreateChannel(selector, rawChannel, createSocketConfig(remoteAddress, false));
scheduleChannel(channel, selector);
return channel;
}
public Socket acceptNioChannel(ServerChannelContext serverContext, Supplier<NioSelector> supplier) throws IOException {
SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverContext);
// Null is returned if there are no pending sockets to accept
if (rawChannel == null) {
return null;
} else {
NioSelector selector = supplier.get();
Socket channel = internalCreateChannel(selector, rawChannel);
scheduleChannel(channel, selector);
return channel;
}
public Socket acceptNioChannel(SocketChannel rawChannel, Supplier<NioSelector> supplier) throws IOException {
setNonBlocking(rawChannel);
NioSelector selector = supplier.get();
InetSocketAddress remoteAddress = getRemoteAddress(rawChannel);
Socket channel = internalCreateChannel(selector, rawChannel, createSocketConfig(remoteAddress, true));
scheduleChannel(channel, selector);
return channel;
}
public ServerSocket openNioServerSocketChannel(InetSocketAddress address, Supplier<NioSelector> supplier) throws IOException {
ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address);
public ServerSocket openNioServerSocketChannel(InetSocketAddress localAddress, Supplier<NioSelector> supplier) throws IOException {
ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel();
setNonBlocking(rawChannel);
NioSelector selector = supplier.get();
ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel);
Config.ServerSocket config = new Config.ServerSocket(tcpReuseAddress, localAddress);
ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel, config);
scheduleServerChannel(serverChannel, selector);
return serverChannel;
}
@ -80,27 +91,38 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
* returned, the channel should be fully created and setup. Read and write contexts and the channel
* exception handler should have been set.
*
* @param selector the channel will be registered with
* @param channel the raw channel
* @param selector the channel will be registered with
* @param channel the raw channel
* @param socketConfig the socket config
* @return the channel
* @throws IOException related to the creation of the channel
*/
public abstract Socket createChannel(NioSelector selector, SocketChannel channel) throws IOException;
public abstract Socket createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException;
/**
* This method should return a new {@link NioServerSocketChannel} implementation. When this method has
* returned, the channel should be fully created and setup.
*
* @param selector the channel will be registered with
* @param channel the raw channel
* @param channel the raw channel
* @param socketConfig the socket config
* @return the server channel
* @throws IOException related to the creation of the channel
*/
public abstract ServerSocket createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException;
public abstract ServerSocket createServerChannel(NioSelector selector, ServerSocketChannel channel, Config.ServerSocket socketConfig)
throws IOException;
private Socket internalCreateChannel(NioSelector selector, SocketChannel rawChannel) throws IOException {
protected InetSocketAddress getRemoteAddress(SocketChannel rawChannel) throws IOException {
InetSocketAddress remoteAddress = (InetSocketAddress) rawChannel.socket().getRemoteSocketAddress();
if (remoteAddress == null) {
throw new IOException("Accepted socket does not have remote address");
}
return remoteAddress;
}
private Socket internalCreateChannel(NioSelector selector, SocketChannel rawChannel, Config.Socket config) throws IOException {
try {
Socket channel = createChannel(selector, rawChannel);
Socket channel = createChannel(selector, rawChannel, config);
assert channel.getContext() != null : "channel context should have been set on channel";
return channel;
} catch (UncheckedIOException e) {
@ -114,9 +136,10 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
}
}
private ServerSocket internalCreateServerChannel(NioSelector selector, ServerSocketChannel rawChannel) throws IOException {
private ServerSocket internalCreateServerChannel(NioSelector selector, ServerSocketChannel rawChannel, Config.ServerSocket config)
throws IOException {
try {
return createServerChannel(selector, rawChannel);
return createServerChannel(selector, rawChannel, config);
} catch (Exception e) {
closeRawChannel(rawChannel, e);
throw e;
@ -141,6 +164,15 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
}
}
private void setNonBlocking(AbstractSelectableChannel rawChannel) throws IOException {
try {
rawChannel.configureBlocking(false);
} catch (IOException e) {
closeRawChannel(rawChannel, e);
throw e;
}
}
private static void closeRawChannel(Closeable c, Exception e) {
try {
c.close();
@ -149,107 +181,19 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
}
}
private Config.Socket createSocketConfig(InetSocketAddress remoteAddress, boolean isAccepted) {
return new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, remoteAddress,
isAccepted);
}
public static class RawChannelFactory {
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean tcpReusedAddress;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
public RawChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReusedAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize) {
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpReusedAddress = tcpReusedAddress;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
SocketChannel openNioChannel() throws IOException {
return SocketChannel.open();
}
SocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException {
SocketChannel socketChannel = SocketChannel.open();
try {
configureSocketChannel(socketChannel);
connect(socketChannel, remoteAddress);
} catch (IOException e) {
closeRawChannel(socketChannel, e);
throw e;
}
return socketChannel;
}
SocketChannel acceptNioChannel(ServerChannelContext serverContext) throws IOException {
ServerSocketChannel rawChannel = serverContext.getChannel().getRawChannel();
assert rawChannel.isBlocking() == false;
SocketChannel socketChannel = accept(rawChannel);
assert rawChannel.isBlocking() == false;
if (socketChannel == null) {
return null;
}
try {
configureSocketChannel(socketChannel);
} catch (IOException e) {
closeRawChannel(socketChannel, e);
throw e;
}
return socketChannel;
}
ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws IOException {
ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.configureBlocking(false);
java.net.ServerSocket socket = serverSocketChannel.socket();
try {
socket.setReuseAddress(tcpReusedAddress);
serverSocketChannel.bind(address);
} catch (IOException e) {
closeRawChannel(serverSocketChannel, e);
throw e;
}
return serverSocketChannel;
}
private static final boolean MAC_OS_X = System.getProperty("os.name").startsWith("Mac OS X");
private static void setSocketOption(CheckedRunnable<SocketException> runnable) throws SocketException {
try {
runnable.run();
} catch (SocketException e) {
if (MAC_OS_X == false) {
// ignore on Mac, see https://github.com/elastic/elasticsearch/issues/41071
throw e;
}
}
}
private void configureSocketChannel(SocketChannel channel) throws IOException {
channel.configureBlocking(false);
java.net.Socket socket = channel.socket();
setSocketOption(() -> socket.setTcpNoDelay(tcpNoDelay));
setSocketOption(() -> socket.setKeepAlive(tcpKeepAlive));
setSocketOption(() -> socket.setReuseAddress(tcpReusedAddress));
if (tcpSendBufferSize > 0) {
setSocketOption(() -> socket.setSendBufferSize(tcpSendBufferSize));
}
if (tcpReceiveBufferSize > 0) {
setSocketOption(() -> socket.setSendBufferSize(tcpReceiveBufferSize));
}
}
public static SocketChannel accept(ServerSocketChannel serverSocketChannel) throws IOException {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<SocketChannel>) serverSocketChannel::accept);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
private static void connect(SocketChannel socketChannel, InetSocketAddress remoteAddress) throws IOException {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Boolean>) () -> socketChannel.connect(remoteAddress));
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
ServerSocketChannel openNioServerSocketChannel() throws IOException {
return ServerSocketChannel.open();
}
}
}

View File

@ -0,0 +1,94 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.net.InetSocketAddress;
public abstract class Config {
private final boolean tcpReuseAddress;
public Config(boolean tcpReuseAddress) {
this.tcpReuseAddress = tcpReuseAddress;
}
public boolean tcpReuseAddress() {
return tcpReuseAddress;
}
public static class Socket extends Config {
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
private final InetSocketAddress remoteAddress;
private final boolean isAccepted;
public Socket(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize, int tcpReceiveBufferSize,
InetSocketAddress remoteAddress, boolean isAccepted) {
super(tcpReuseAddress);
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
this.remoteAddress = remoteAddress;
this.isAccepted = isAccepted;
}
public boolean tcpNoDelay() {
return tcpNoDelay;
}
public boolean tcpKeepAlive() {
return tcpKeepAlive;
}
public int tcpSendBufferSize() {
return tcpSendBufferSize;
}
public int tcpReceiveBufferSize() {
return tcpReceiveBufferSize;
}
public boolean isAccepted() {
return isAccepted;
}
public InetSocketAddress getRemoteAddress() {
return remoteAddress;
}
}
public static class ServerSocket extends Config {
private InetSocketAddress localAddress;
public ServerSocket(boolean tcpReuseAddress, InetSocketAddress localAddress) {
super(tcpReuseAddress);
this.localAddress = localAddress;
}
public InetSocketAddress getLocalAddress() {
return localAddress;
}
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.nio;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
public class NioServerSocketChannel extends NioChannel {
@ -32,7 +33,6 @@ public class NioServerSocketChannel extends NioChannel {
public NioServerSocketChannel(ServerSocketChannel serverSocketChannel) {
this.serverSocketChannel = serverSocketChannel;
attemptToSetLocalAddress();
}
/**
@ -49,6 +49,10 @@ public class NioServerSocketChannel extends NioChannel {
}
}
public void addBindListener(BiConsumer<Void, Exception> listener) {
context.addBindListener(listener);
}
@Override
public InetSocketAddress getLocalAddress() {
attemptToSetLocalAddress();

View File

@ -19,8 +19,6 @@
package org.elasticsearch.nio;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;
@ -30,17 +28,12 @@ public class NioSocketChannel extends NioChannel {
private final AtomicBoolean contextSet = new AtomicBoolean(false);
private final SocketChannel socketChannel;
private final InetSocketAddress remoteAddress;
private volatile InetSocketAddress remoteAddress;
private volatile InetSocketAddress localAddress;
private SocketChannelContext context;
private volatile SocketChannelContext context;
public NioSocketChannel(SocketChannel socketChannel) {
this.socketChannel = socketChannel;
try {
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
public void setContext(SocketChannelContext context) {
@ -70,6 +63,9 @@ public class NioSocketChannel extends NioChannel {
}
public InetSocketAddress getRemoteAddress() {
if (remoteAddress == null) {
remoteAddress = (InetSocketAddress) socketChannel.socket().getRemoteSocketAddress();
}
return remoteAddress;
}
@ -81,7 +77,7 @@ public class NioSocketChannel extends NioChannel {
public String toString() {
return "NioSocketChannel{" +
"localAddress=" + getLocalAddress() +
", remoteAddress=" + remoteAddress +
", remoteAddress=" + getRemoteAddress() +
'}';
}
}

View File

@ -19,9 +19,18 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.concurrent.CompletableContext;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
@ -29,23 +38,49 @@ public class ServerChannelContext extends ChannelContext<ServerSocketChannel> {
private final NioServerSocketChannel channel;
private final NioSelector selector;
private final Config.ServerSocket config;
private final Consumer<NioSocketChannel> acceptor;
private final AtomicBoolean isClosing = new AtomicBoolean(false);
private final ChannelFactory<?, ?> channelFactory;
private final CompletableContext<Void> bindContext = new CompletableContext<>();
public ServerChannelContext(NioServerSocketChannel channel, ChannelFactory<?, ?> channelFactory, NioSelector selector,
Consumer<NioSocketChannel> acceptor, Consumer<Exception> exceptionHandler) {
Config.ServerSocket config, Consumer<NioSocketChannel> acceptor,
Consumer<Exception> exceptionHandler) {
super(channel.getRawChannel(), exceptionHandler);
this.channel = channel;
this.channelFactory = channelFactory;
this.selector = selector;
this.config = config;
this.acceptor = acceptor;
}
public void acceptChannels(Supplier<NioSelector> selectorSupplier) throws IOException {
NioSocketChannel acceptedChannel;
while ((acceptedChannel = channelFactory.acceptNioChannel(this, selectorSupplier)) != null) {
acceptor.accept(acceptedChannel);
SocketChannel acceptedChannel;
while ((acceptedChannel = accept(rawChannel)) != null) {
NioSocketChannel nioChannel = channelFactory.acceptNioChannel(acceptedChannel, selectorSupplier);
acceptor.accept(nioChannel);
}
}
public void addBindListener(BiConsumer<Void, Exception> listener) {
bindContext.addListener(listener);
}
@Override
protected void register() throws IOException {
super.register();
configureSocket(rawChannel.socket());
InetSocketAddress localAddress = config.getLocalAddress();
try {
rawChannel.bind(localAddress);
bindContext.complete(null);
} catch (IOException e) {
IOException exception = new IOException("Failed to bind server socket channel {localAddress=" + localAddress + "}.", e);
bindContext.completeExceptionally(exception);
throw exception;
}
}
@ -66,4 +101,18 @@ public class ServerChannelContext extends ChannelContext<ServerSocketChannel> {
return channel;
}
private void configureSocket(ServerSocket socket) throws IOException {
socket.setReuseAddress(config.tcpReuseAddress());
}
protected static SocketChannel accept(ServerSocketChannel serverSocketChannel) throws IOException {
try {
assert serverSocketChannel.isBlocking() == false;
SocketChannel channel = AccessController.doPrivileged((PrivilegedExceptionAction<SocketChannel>) serverSocketChannel::accept);
assert serverSocketChannel.isBlocking() == false;
return channel;
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
}

View File

@ -24,9 +24,14 @@ import org.elasticsearch.nio.utils.ByteBufferUtils;
import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean;
@ -47,19 +52,23 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
protected final NioSocketChannel channel;
protected final InboundChannelBuffer channelBuffer;
protected final AtomicBoolean isClosing = new AtomicBoolean(false);
private final NioChannelHandler readWriteHandler;
private final NioChannelHandler channelHandler;
private final NioSelector selector;
private final Config.Socket socketConfig;
private final CompletableContext<Void> connectContext = new CompletableContext<>();
private final LinkedList<FlushOperation> pendingFlushes = new LinkedList<>();
private boolean closeNow;
private boolean socketOptionsSet;
private Exception connectException;
protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
Consumer<Exception> exceptionHandler, NioChannelHandler channelHandler,
InboundChannelBuffer channelBuffer) {
super(channel.getRawChannel(), exceptionHandler);
this.selector = selector;
this.channel = channel;
this.readWriteHandler = readWriteHandler;
this.socketConfig = socketConfig;
this.channelHandler = channelHandler;
this.channelBuffer = channelBuffer;
}
@ -73,6 +82,22 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
return channel;
}
@Override
protected void register() throws IOException {
super.register();
configureSocket(rawChannel.socket(), false);
if (socketConfig.isAccepted() == false) {
InetSocketAddress remoteAddress = socketConfig.getRemoteAddress();
try {
connect(rawChannel, remoteAddress);
} catch (IOException e) {
throw new IOException("Failed to initiate socket channel connection {remoteAddress=" + remoteAddress + "}.", e);
}
}
}
public void addConnectListener(BiConsumer<Void, Exception> listener) {
connectContext.addListener(listener);
}
@ -117,6 +142,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
}
if (isConnected) {
connectContext.complete(null);
configureSocket(rawChannel.socket(), true);
}
return isConnected;
}
@ -127,14 +153,14 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
return;
}
WriteOperation writeOperation = readWriteHandler.createWriteOperation(this, message, listener);
WriteOperation writeOperation = channelHandler.createWriteOperation(this, message, listener);
getSelector().queueWrite(writeOperation);
}
public void queueWriteOperation(WriteOperation writeOperation) {
getSelector().assertOnSelectorThread();
pendingFlushes.addAll(readWriteHandler.writeToBytes(writeOperation));
pendingFlushes.addAll(channelHandler.writeToBytes(writeOperation));
}
public abstract int read() throws IOException;
@ -157,7 +183,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
@Override
protected void channelActive() throws IOException {
readWriteHandler.channelActive();
channelHandler.channelActive();
}
@Override
@ -174,14 +200,14 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
isClosing.set(true);
// Poll for new flush operations to close
pendingFlushes.addAll(readWriteHandler.pollFlushOperations());
pendingFlushes.addAll(channelHandler.pollFlushOperations());
FlushOperation flushOperation;
while ((flushOperation = pendingFlushes.pollFirst()) != null) {
selector.executeFailedListener(flushOperation.getListener(), new ClosedChannelException());
}
try {
readWriteHandler.close();
channelHandler.close();
} catch (IOException e) {
closingExceptions.add(e);
}
@ -196,12 +222,12 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
protected void handleReadBytes() throws IOException {
int bytesConsumed = Integer.MAX_VALUE;
while (isOpen() && bytesConsumed > 0 && channelBuffer.getIndex() > 0) {
bytesConsumed = readWriteHandler.consumeReads(channelBuffer);
bytesConsumed = channelHandler.consumeReads(channelBuffer);
channelBuffer.release(bytesConsumed);
}
// Some protocols might produce messages to flush during a read operation.
pendingFlushes.addAll(readWriteHandler.pollFlushOperations());
pendingFlushes.addAll(channelHandler.pollFlushOperations());
}
public boolean readyForFlush() {
@ -217,7 +243,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
public abstract boolean selectorShouldClose();
protected boolean closeNow() {
return closeNow || readWriteHandler.closeNow();
return closeNow || channelHandler.closeNow();
}
protected void setCloseNow() {
@ -288,4 +314,42 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
}
return totalBytesFlushed;
}
private void configureSocket(Socket socket, boolean isConnectComplete) throws IOException {
if (socketOptionsSet) {
return;
}
try {
// Set reuse address first as it must be set before a bind call. Some implementations throw
// exceptions on other socket options if the channel is not connected. But setting reuse first,
// we ensure that it is properly set before any bind attempt.
socket.setReuseAddress(socketConfig.tcpReuseAddress());
socket.setKeepAlive(socketConfig.tcpKeepAlive());
socket.setTcpNoDelay(socketConfig.tcpNoDelay());
int tcpSendBufferSize = socketConfig.tcpSendBufferSize();
if (tcpSendBufferSize > 0) {
socket.setSendBufferSize(tcpSendBufferSize);
}
int tcpReceiveBufferSize = socketConfig.tcpReceiveBufferSize();
if (tcpReceiveBufferSize > 0) {
socket.setReceiveBufferSize(tcpReceiveBufferSize);
}
socketOptionsSet = true;
} catch (IOException e) {
if (isConnectComplete) {
throw e;
}
// Ignore if not connect complete. Some implementations fail on setting socket options if the
// socket is not connected. We will try again after connection.
}
}
private static void connect(SocketChannel socketChannel, InetSocketAddress remoteAddress) throws IOException {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Boolean>) () -> socketChannel.connect(remoteAddress));
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
}

View File

@ -62,7 +62,7 @@ public class BytesChannelContextTests extends ESTestCase {
channelBuffer = InboundChannelBuffer.allocatingInstance();
TestReadWriteHandler handler = new TestReadWriteHandler(readConsumer);
when(channel.getRawChannel()).thenReturn(rawChannel);
context = new BytesChannelContext(channel, selector, mock(Consumer.class), handler, channelBuffer);
context = new BytesChannelContext(channel, selector, mock(Config.Socket.class), mock(Consumer.class), handler, channelBuffer);
when(selector.isOnCurrentThread()).thenReturn(true);
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);

View File

@ -31,7 +31,6 @@ import java.nio.channels.SocketChannel;
import java.util.function.Supplier;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -69,10 +68,7 @@ public class ChannelFactoryTests extends ESTestCase {
}
public void testAcceptChannel() throws IOException {
ServerChannelContext serverChannelContext = mock(ServerChannelContext.class);
when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel);
NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier);
NioSocketChannel channel = channelFactory.acceptNioChannel(rawChannel, socketSelectorSupplier);
verify(socketSelector).scheduleForRegistration(channel);
@ -80,18 +76,16 @@ public class ChannelFactoryTests extends ESTestCase {
}
public void testAcceptedChannelRejected() throws IOException {
ServerChannelContext serverChannelContext = mock(ServerChannelContext.class);
when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel);
doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any());
expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier));
expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(rawChannel, socketSelectorSupplier));
assertFalse(rawChannel.isOpen());
}
public void testOpenChannel() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel);
when(rawChannelFactory.openNioChannel()).thenReturn(rawChannel);
NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelectorSupplier);
@ -102,7 +96,7 @@ public class ChannelFactoryTests extends ESTestCase {
public void testOpenedChannelRejected() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel);
when(rawChannelFactory.openNioChannel()).thenReturn(rawChannel);
doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any());
expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelectorSupplier));
@ -112,7 +106,7 @@ public class ChannelFactoryTests extends ESTestCase {
public void testOpenServerChannel() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel);
when(rawChannelFactory.openNioServerSocketChannel()).thenReturn(rawServerChannel);
NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier);
@ -123,7 +117,7 @@ public class ChannelFactoryTests extends ESTestCase {
public void testOpenedServerChannelRejected() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel);
when(rawChannelFactory.openNioServerSocketChannel()).thenReturn(rawServerChannel);
doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any());
expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier));
@ -134,19 +128,26 @@ public class ChannelFactoryTests extends ESTestCase {
private static class TestChannelFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> {
TestChannelFactory(RawChannelFactory rawChannelFactory) {
super(rawChannelFactory);
super(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, rawChannelFactory);
}
@Override
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioSocketChannel nioSocketChannel = new NioSocketChannel(channel);
nioSocketChannel.setContext(mock(SocketChannelContext.class));
return nioSocketChannel;
}
@Override
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException {
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
return new NioServerSocketChannel(channel);
}
@Override
protected InetSocketAddress getRemoteAddress(SocketChannel rawChannel) throws IOException {
// Override this functionality to avoid having to connect the accepted channel
return mock(InetSocketAddress.class);
}
}
}

View File

@ -23,7 +23,9 @@ import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
@ -33,7 +35,6 @@ import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -67,6 +68,7 @@ public class EventHandlerTests extends ESTestCase {
SocketChannel rawChannel = mock(SocketChannel.class);
when(rawChannel.finishConnect()).thenReturn(true);
NioSocketChannel channel = new NioSocketChannel(rawChannel);
when(rawChannel.socket()).thenReturn(mock(Socket.class));
context = new DoNotRegisterSocketContext(channel, selector, channelExceptionHandler, readWriteHandler);
channel.setContext(context);
handler.handleRegistration(context);
@ -104,22 +106,8 @@ public class EventHandlerTests extends ESTestCase {
assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps());
}
public void testHandleAcceptCallsChannelFactory() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class));
NioSocketChannel nullChannel = null;
when(channelFactory.acceptNioChannel(same(serverContext), same(selectorSupplier))).thenReturn(childChannel, nullChannel);
handler.acceptChannel(serverContext);
verify(channelFactory, times(2)).acceptNioChannel(same(serverContext), same(selectorSupplier));
}
public void testHandleAcceptCallsServerAcceptCallback() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class));
SocketChannelContext childContext = mock(SocketChannelContext.class);
childChannel.setContext(childContext);
public void testHandleAcceptAccept() throws IOException {
ServerChannelContext serverChannelContext = mock(ServerChannelContext.class);
when(channelFactory.acceptNioChannel(same(serverContext), same(selectorSupplier))).thenReturn(childChannel);
handler.acceptChannel(serverChannelContext);
@ -254,7 +242,7 @@ public class EventHandlerTests extends ESTestCase {
DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler handler) {
super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance());
super(channel, selector, getSocketConfig(), exceptionHandler, handler, InboundChannelBuffer.allocatingInstance());
}
@Override
@ -270,7 +258,7 @@ public class EventHandlerTests extends ESTestCase {
@SuppressWarnings("unchecked")
DoNotRegisterServerContext(NioServerSocketChannel channel, NioSelector selector, Consumer<NioSocketChannel> acceptor) {
super(channel, channelFactory, selector, acceptor, mock(Consumer.class));
super(channel, channelFactory, selector, getServerSocketConfig(), acceptor, mock(Consumer.class));
}
@Override
@ -280,4 +268,13 @@ public class EventHandlerTests extends ESTestCase {
selectionKey.attach(this);
}
}
private static Config.ServerSocket getServerSocketConfig() {
return new Config.ServerSocket(randomBoolean(), mock(InetSocketAddress.class));
}
private static Config.Socket getSocketConfig() {
return new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class),
randomBoolean());
}
}

View File

@ -19,12 +19,16 @@
package org.elasticsearch.nio;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
@ -40,11 +44,13 @@ import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.isNull;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@SuppressForbidden(reason = "allow call to socket connect")
public class SocketChannelContextTests extends ESTestCase {
private SocketChannel rawChannel;
@ -55,6 +61,7 @@ public class SocketChannelContextTests extends ESTestCase {
private NioSelector selector;
private NioChannelHandler handler;
private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
private Socket rawSocket;
@SuppressWarnings("unchecked")
@Before
@ -76,6 +83,8 @@ public class SocketChannelContextTests extends ESTestCase {
ioBuffer.clear();
return ioBuffer;
});
rawSocket = mock(Socket.class);
when(rawChannel.socket()).thenReturn(rawSocket);
}
public void testIOExceptionSetIfEncountered() throws IOException {
@ -101,6 +110,31 @@ public class SocketChannelContextTests extends ESTestCase {
assertTrue(context.closeNow());
}
public void testRegisterInitiatesConnect() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
boolean isAccepted = randomBoolean();
Config.Socket config;
boolean tcpNoDelay = randomBoolean();
boolean tcpKeepAlive = randomBoolean();
boolean tcpReuseAddress = randomBoolean();
int tcpSendBufferSize = randomIntBetween(1000, 2000);
int tcpReceiveBufferSize = randomIntBetween(1000, 2000);
config = new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, address, isAccepted);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer, config);
context.register();
if (isAccepted) {
verify(rawChannel, times(0)).connect(any(InetSocketAddress.class));
} else {
verify(rawChannel).connect(same(address));
}
verify(rawSocket).setTcpNoDelay(tcpNoDelay);
verify(rawSocket).setKeepAlive(tcpKeepAlive);
verify(rawSocket).setReuseAddress(tcpReuseAddress);
verify(rawSocket).setSendBufferSize(tcpSendBufferSize);
verify(rawSocket).setReceiveBufferSize(tcpReceiveBufferSize);
}
public void testConnectSucceeds() throws IOException {
AtomicBoolean listenerCalled = new AtomicBoolean(false);
when(rawChannel.finishConnect()).thenReturn(false, true);
@ -142,6 +176,29 @@ public class SocketChannelContextTests extends ESTestCase {
assertSame(ioException, exception.get());
}
public void testConnectCanSetSocketOptions() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
Config.Socket config;
boolean tcpNoDelay = randomBoolean();
boolean tcpKeepAlive = randomBoolean();
boolean tcpReuseAddress = randomBoolean();
int tcpSendBufferSize = randomIntBetween(1000, 2000);
int tcpReceiveBufferSize = randomIntBetween(1000, 2000);
config = new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, address, false);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer, config);
doThrow(new SocketException()).doNothing().when(rawSocket).setReuseAddress(tcpReuseAddress);
context.register();
when(rawChannel.finishConnect()).thenReturn(true);
context.connect();
verify(rawSocket, times(2)).setReuseAddress(tcpReuseAddress);
verify(rawSocket).setKeepAlive(tcpKeepAlive);
verify(rawSocket).setTcpNoDelay(tcpNoDelay);
verify(rawSocket).setSendBufferSize(tcpSendBufferSize);
verify(rawSocket).setReceiveBufferSize(tcpReceiveBufferSize);
}
public void testChannelActiveCallsHandler() throws IOException {
context.channelActive();
verify(handler).channelActive();
@ -262,7 +319,8 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel);
when(channel.isOpen()).thenReturn(true);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, handler, buffer);
BytesChannelContext context = new BytesChannelContext(channel, selector, mock(Config.Socket.class), exceptionHandler, handler,
buffer);
context.closeFromSelector();
verify(handler).close();
}
@ -379,11 +437,21 @@ public class SocketChannelContextTests extends ESTestCase {
assertEquals(1, flushOperation.getBuffersToWrite()[0].position());
}
private static Config.Socket getSocketConfig() {
return new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class),
randomBoolean());
}
private static class TestSocketChannelContext extends SocketChannelContext {
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, getSocketConfig());
}
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer, Config.Socket config) {
super(channel, selector, config, exceptionHandler, readWriteHandler, channelBuffer);
}
@Override

View File

@ -29,20 +29,13 @@ import java.net.InetSocketAddress;
public class Netty4TcpServerChannel implements TcpServerChannel {
private final Channel channel;
private final String profile;
private final CompletableContext<Void> closeContext = new CompletableContext<>();
Netty4TcpServerChannel(Channel channel, String profile) {
Netty4TcpServerChannel(Channel channel) {
this.channel = channel;
this.profile = profile;
Netty4TcpChannel.addListener(this.channel.closeFuture(), closeContext);
}
@Override
public String getProfile() {
return profile;
}
@Override
public InetSocketAddress getLocalAddress() {
return (InetSocketAddress) channel.localAddress();

View File

@ -239,7 +239,7 @@ public class Netty4Transport extends TcpTransport {
@Override
protected Netty4TcpServerChannel bind(String name, InetSocketAddress address) {
Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel();
Netty4TcpServerChannel esChannel = new Netty4TcpServerChannel(channel, name);
Netty4TcpServerChannel esChannel = new Netty4TcpServerChannel(channel);
channel.attr(SERVER_CHANNEL_KEY).set(esChannel);
return esChannel;
}

View File

@ -22,6 +22,8 @@ package org.elasticsearch.http.nio;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
@ -33,6 +35,7 @@ import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpServerChannel;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
@ -130,7 +133,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
@Override
protected HttpServerChannel bind(InetSocketAddress socketAddress) throws IOException {
return nioGroup.bindServerChannel(socketAddress, channelFactory);
NioHttpServerChannel httpServerChannel = nioGroup.bindServerChannel(socketAddress, channelFactory);
PlainActionFuture<Void> future = PlainActionFuture.newFuture();
httpServerChannel.addBindListener(ActionListener.toBiConsumer(future));
future.actionGet();
return httpServerChannel;
}
protected ChannelFactory<NioHttpServerChannel, NioHttpChannel> channelFactory() {
@ -144,27 +151,29 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
private class HttpChannelFactory extends ChannelFactory<NioHttpServerChannel, NioHttpChannel> {
private HttpChannelFactory() {
super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize));
super(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize);
}
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
HttpReadWriteHandler handler = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline,
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler,
new InboundChannelBuffer(pageAllocator));
httpChannel.setContext(context);
return httpChannel;
}
@Override
public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException {
public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(httpServerChannel, e);
Consumer<NioSocketChannel> acceptor = NioHttpServerTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, acceptor, exceptionHandler);
ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, socketConfig, acceptor,
exceptionHandler);
httpServerChannel.setContext(context);
return httpServerChannel;
}

View File

@ -31,22 +31,14 @@ import java.nio.channels.ServerSocketChannel;
*/
public class NioTcpServerChannel extends NioServerSocketChannel implements TcpServerChannel {
private final String profile;
public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) {
public NioTcpServerChannel(ServerSocketChannel socketChannel) {
super(socketChannel);
this.profile = profile;
}
public void close() {
getContext().closeChannel();
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener));

View File

@ -23,6 +23,8 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
@ -31,6 +33,7 @@ import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
@ -38,6 +41,7 @@ import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportSettings;
import java.io.IOException;
import java.net.InetSocketAddress;
@ -70,7 +74,11 @@ public class NioTransport extends TcpTransport {
@Override
protected NioTcpServerChannel bind(String name, InetSocketAddress address) throws IOException {
TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
return nioGroup.bindServerChannel(address, channelFactory);
NioTcpServerChannel serverChannel = nioGroup.bindServerChannel(address, channelFactory);
PlainActionFuture<Void> future = PlainActionFuture.newFuture();
serverChannel.addBindListener(ActionListener.toBiConsumer(future));
future.actionGet();
return serverChannel;
}
@Override
@ -85,7 +93,7 @@ public class NioTransport extends TcpTransport {
try {
nioGroup = groupFactory.getTransportGroup();
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
ProfileSettings clientProfileSettings = new ProfileSettings(settings, TransportSettings.DEFAULT_PROFILE);
clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings);
if (NetworkService.NETWORK_SERVER.get(settings)) {
@ -133,8 +141,9 @@ public class NioTransport extends TcpTransport {
protected abstract class TcpChannelFactory extends ChannelFactory<NioTcpServerChannel, NioTcpChannel> {
protected TcpChannelFactory(RawChannelFactory rawChannelFactory) {
super(rawChannelFactory);
protected TcpChannelFactory(ProfileSettings profileSettings) {
super(profileSettings.tcpNoDelay, profileSettings.tcpKeepAlive, profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()), Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
}
}
@ -144,32 +153,29 @@ public class NioTransport extends TcpTransport {
private final String profileName;
private TcpChannelFactoryImpl(ProfileSettings profileSettings, boolean isClient) {
super(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
super(profileSettings);
this.isClient = isClient;
this.profileName = profileSettings.profileName;
}
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) {
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
TcpReadWriteHandler handler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler,
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, socketConfig, exceptionHandler, handler,
new InboundChannelBuffer(pageAllocator));
nioChannel.setContext(context);
return nioChannel;
}
@Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
NioTcpServerChannel nioChannel = new NioTcpServerChannel(profileName, channel);
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioTcpServerChannel nioChannel = new NioTcpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(nioChannel, e);
Consumer<NioSocketChannel> acceptor = NioTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler);
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler);
nioChannel.setContext(context);
return nioChannel;
}

View File

@ -31,8 +31,8 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.network.NetworkService;
@ -42,14 +42,15 @@ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.EventHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.tasks.Task;
@ -149,7 +150,8 @@ class NioHttpClient implements Closeable {
connectFuture.actionGet();
for (HttpRequest request : requests) {
nioSocketChannel.getContext().sendMessage(request, (v, e) -> {});
nioSocketChannel.getContext().sendMessage(request, (v, e) -> {
});
}
if (latch.await(30L, TimeUnit.SECONDS) == false) {
fail("Failed to get all expected responses.");
@ -177,17 +179,17 @@ class NioHttpClient implements Closeable {
private final Collection<FullHttpResponse> content;
private ClientChannelFactory(CountDownLatch latch, Collection<FullHttpResponse> content) {
super(new RawChannelFactory(NetworkService.TCP_NO_DELAY.get(Settings.EMPTY),
super(NetworkService.TCP_NO_DELAY.get(Settings.EMPTY),
NetworkService.TCP_KEEP_ALIVE.get(Settings.EMPTY),
NetworkService.TCP_REUSE_ADDRESS.get(Settings.EMPTY),
Math.toIntExact(NetworkService.TCP_SEND_BUFFER_SIZE.get(Settings.EMPTY).getBytes()),
Math.toIntExact(NetworkService.TCP_RECEIVE_BUFFER_SIZE.get(Settings.EMPTY).getBytes())));
Math.toIntExact(NetworkService.TCP_RECEIVE_BUFFER_SIZE.get(Settings.EMPTY).getBytes()));
this.latch = latch;
this.content = content;
}
@Override
public NioSocketChannel createChannel(NioSelector selector, java.nio.channels.SocketChannel channel) throws IOException {
public NioSocketChannel createChannel(NioSelector selector, java.nio.channels.SocketChannel channel, Config.Socket socketConfig) {
NioSocketChannel nioSocketChannel = new NioSocketChannel(channel);
HttpClientHandler handler = new HttpClientHandler(nioSocketChannel, latch, content);
Consumer<Exception> exceptionHandler = (e) -> {
@ -195,14 +197,15 @@ class NioHttpClient implements Closeable {
onException(e);
nioSocketChannel.close();
};
SocketChannelContext context = new BytesChannelContext(nioSocketChannel, selector, exceptionHandler, handler,
SocketChannelContext context = new BytesChannelContext(nioSocketChannel, selector, socketConfig, exceptionHandler, handler,
InboundChannelBuffer.allocatingInstance());
nioSocketChannel.setContext(context);
return nioSocketChannel;
}
@Override
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
throw new UnsupportedOperationException("Cannot create server channel");
}
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel;
@ -57,20 +58,21 @@ public class NioGroupFactoryTests extends ESTestCase {
private static class BindingFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> {
private BindingFactory() {
super(new ChannelFactory.RawChannelFactory(false, false, false, -1, -1));
super(false, false, false, -1, -1);
}
@Override
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
throw new IOException("boom");
}
@Override
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException {
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioServerSocketChannel nioChannel = new NioServerSocketChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> {};
Consumer<NioSocketChannel> acceptor = (c) -> {};
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler);
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler);
nioChannel.setContext(context);
return nioChannel;
}

View File

@ -31,11 +31,6 @@ import java.net.InetSocketAddress;
*/
public interface TcpServerChannel extends CloseableChannel {
/**
* This returns the profile for this channel.
*/
String getProfile();
/**
* Returns the local address for this channel.
*

View File

@ -25,6 +25,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -37,6 +38,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSelectorGroup;
@ -91,7 +93,11 @@ public class MockNioTransport extends TcpTransport {
@Override
protected MockServerChannel bind(String name, InetSocketAddress address) throws IOException {
MockTcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
return nioGroup.bindServerChannel(address, channelFactory);
MockServerChannel serverChannel = nioGroup.bindServerChannel(address, channelFactory);
PlainActionFuture<Void> future = PlainActionFuture.newFuture();
serverChannel.addBindListener(ActionListener.toBiConsumer(future));
future.actionGet();
return serverChannel;
}
@Override
@ -190,17 +196,17 @@ public class MockNioTransport extends TcpTransport {
private final String profileName;
private MockTcpChannelFactory(boolean isClient, ProfileSettings profileSettings, String profileName) {
super(new RawChannelFactory(profileSettings.tcpNoDelay,
super(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
this.isClient = isClient;
this.profileName = profileName;
}
@Override
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
IntFunction<Page> pageSupplier = (length) -> {
if (length > PageCacheRecycler.BYTE_PAGE_SIZE) {
@ -211,7 +217,7 @@ public class MockNioTransport extends TcpTransport {
}
};
MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e),
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, socketConfig, (e) -> exceptionCaught(nioChannel, e),
readWriteHandler, new InboundChannelBuffer(pageSupplier));
nioChannel.setContext(context);
nioChannel.addConnectListener((v, e) -> {
@ -229,18 +235,19 @@ public class MockNioTransport extends TcpTransport {
}
@Override
public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel);
public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, Config.ServerSocket socketConfig) {
MockServerChannel nioServerChannel = new MockServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> logger.error(() ->
new ParameterizedMessage("exception from server channel caught on transport layer [{}]", channel), e);
ServerChannelContext context = new ServerChannelContext(nioServerChannel, null, selector, null,
exceptionHandler) {
ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, socketConfig,
MockNioTransport.this::acceptChannel, exceptionHandler) {
@Override
public void acceptChannels(Supplier<NioSelector> selectorSupplier) throws IOException {
int acceptCount = 0;
NioSocketChannel acceptedChannel;
while ((acceptedChannel = MockTcpChannelFactory.this.acceptNioChannel(this, selectorSupplier)) != null) {
acceptChannel(acceptedChannel);
SocketChannel acceptedChannel;
while ((acceptedChannel = accept(rawChannel)) != null) {
NioSocketChannel nioChannel = MockTcpChannelFactory.this.acceptNioChannel(acceptedChannel, selectorSupplier);
acceptChannel(nioChannel);
++acceptCount;
if (acceptCount % 100 == 0) {
logger.warn("Accepted [{}] connections in a single select loop iteration on [{}]", acceptCount, channel);
@ -272,11 +279,8 @@ public class MockNioTransport extends TcpTransport {
private static class MockServerChannel extends NioServerSocketChannel implements TcpServerChannel {
private final String profile;
MockServerChannel(String profile, ServerSocketChannel channel) {
MockServerChannel(ServerSocketChannel channel) {
super(channel);
this.profile = profile;
}
@Override
@ -284,11 +288,6 @@ public class MockNioTransport extends TcpTransport {
getContext().closeChannel();
}
@Override
public String getProfile() {
return profile;
}
@Override
public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener));

View File

@ -13,6 +13,7 @@ import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.WriteOperation;
import javax.net.ssl.SSLEngine;
@ -41,15 +42,17 @@ public final class SSLChannelContext extends SocketChannelContext {
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
NioChannelHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
Consumer<Exception> exceptionHandler, SSLDriver sslDriver, NioChannelHandler readWriteHandler,
InboundChannelBuffer applicationBuffer) {
this(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
applicationBuffer);
}
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
NioChannelHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
Consumer<Exception> exceptionHandler, SSLDriver sslDriver, NioChannelHandler readWriteHandler,
InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) {
super(channel, selector, socketConfig, exceptionHandler, readWriteHandler, channelBuffer);
this.sslDriver = sslDriver;
this.networkReadBuffer = networkReadBuffer;
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.http.nio.NioHttpServerChannel;
import org.elasticsearch.http.nio.NioHttpServerTransport;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
@ -82,17 +83,17 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
class SecurityHttpChannelFactory extends ChannelFactory<NioHttpServerChannel, NioHttpChannel> {
private SecurityHttpChannelFactory() {
super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize));
super(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize);
}
@Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
final NioChannelHandler handler;
if (ipFilter != null) {
handler = new NioIPFilter(httpHandler, httpChannel.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);
handler = new NioIPFilter(httpHandler, socketConfig.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);
} else {
handler = httpHandler;
}
@ -113,10 +114,10 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
}
SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer,
context = new SSLChannelContext(httpChannel, selector, socketConfig, exceptionHandler, sslDriver, handler, networkBuffer,
applicationBuffer);
} else {
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, handler, networkBuffer);
context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler, networkBuffer);
}
httpChannel.setContext(context);
@ -124,11 +125,13 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
}
@Override
public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(httpServerChannel, e);
Consumer<NioSocketChannel> acceptor = SecurityNioHttpServerTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, acceptor, exceptionHandler);
ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, socketConfig, acceptor,
exceptionHandler);
httpServerChannel.setContext(context);
return httpServerChannel;

View File

@ -16,7 +16,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector;
@ -111,9 +111,6 @@ public class SecurityNioTransport extends NioTransport {
@Override
protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) {
return (node) -> {
final ChannelFactory.RawChannelFactory rawChannelFactory = new ChannelFactory.RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
SNIHostName serverName;
String configuredServerName = node.getAttributes().get("server_name");
if (configuredServerName != null) {
@ -125,7 +122,7 @@ public class SecurityNioTransport extends NioTransport {
} else {
serverName = null;
}
return new SecurityClientTcpChannelFactory(rawChannelFactory, serverName);
return new SecurityClientTcpChannelFactory(profileSettings, serverName);
};
}
@ -135,26 +132,18 @@ public class SecurityNioTransport extends NioTransport {
private final boolean isClient;
private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
this(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), profileSettings.profileName, isClient);
}
private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String profileName, boolean isClient) {
super(rawChannelFactory);
this.profileName = profileName;
super(profileSettings);
this.profileName = profileSettings.profileName;
this.isClient = isClient;
}
@Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
final NioChannelHandler handler;
if (ipFilter != null) {
handler = new NioIPFilter(readWriteHandler, nioChannel.getRemoteAddress(), ipFilter, profileName);
handler = new NioIPFilter(readWriteHandler, socketConfig.getRemoteAddress(), ipFilter, profileName);
} else {
handler = readWriteHandler;
}
@ -163,12 +152,12 @@ public class SecurityNioTransport extends NioTransport {
SocketChannelContext context;
if (sslEnabled) {
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient);
SSLDriver sslDriver = new SSLDriver(createSSLEngine(socketConfig), pageAllocator, isClient);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer,
context = new SSLChannelContext(nioChannel, selector, socketConfig, exceptionHandler, sslDriver, handler, networkBuffer,
applicationBuffer);
} else {
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, handler, networkBuffer);
context = new BytesChannelContext(nioChannel, selector, socketConfig, exceptionHandler, handler, networkBuffer);
}
nioChannel.setContext(context);
@ -176,24 +165,25 @@ public class SecurityNioTransport extends NioTransport {
}
@Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException {
NioTcpServerChannel nioChannel = new NioTcpServerChannel(profileName, channel);
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioTcpServerChannel nioChannel = new NioTcpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(nioChannel, e);
Consumer<NioSocketChannel> acceptor = SecurityNioTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler);
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler);
nioChannel.setContext(context);
return nioChannel;
}
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException {
protected SSLEngine createSSLEngine(Config.Socket socketConfig) throws IOException {
SSLEngine sslEngine;
SSLConfiguration defaultConfig = profileConfiguration.get(TransportSettings.DEFAULT_PROFILE);
SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig);
boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled();
if (hostnameVerificationEnabled) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress();
if (hostnameVerificationEnabled && socketConfig.isAccepted() == false) {
InetSocketAddress remoteAddress = socketConfig.getRemoteAddress();
// we create the socket based on the name given. don't reverse DNS
sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort());
sslEngine = sslService.createSSLEngine(sslConfig, remoteAddress.getHostString(), remoteAddress.getPort());
} else {
sslEngine = sslService.createSSLEngine(sslConfig, null, -1);
}
@ -205,19 +195,20 @@ public class SecurityNioTransport extends NioTransport {
private final SNIHostName serverName;
private SecurityClientTcpChannelFactory(RawChannelFactory rawChannelFactory, SNIHostName serverName) {
super(rawChannelFactory, TransportSettings.DEFAULT_PROFILE, true);
private SecurityClientTcpChannelFactory(ProfileSettings profileSettings, SNIHostName serverName) {
super(profileSettings, true);
this.serverName = serverName;
}
@Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
throw new AssertionError("Cannot create TcpServerChannel with client factory");
}
@Override
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException {
SSLEngine sslEngine = super.createSSLEngine(channel);
protected SSLEngine createSSLEngine(Config.Socket socketConfig) throws IOException {
SSLEngine sslEngine = super.createSSLEngine(socketConfig);
if (serverName != null) {
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(serverName));

View File

@ -14,6 +14,7 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.TaskScheduler;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.test.ESTestCase;
@ -23,6 +24,7 @@ import org.mockito.stubbing.Answer;
import javax.net.ssl.SSLException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
@ -55,6 +57,7 @@ public class SSLChannelContextTests extends ESTestCase {
private Consumer exceptionHandler;
private SSLDriver sslDriver;
private int messageLength;
private Config.Socket socketConfig;
@Before
@SuppressWarnings("unchecked")
@ -73,7 +76,8 @@ public class SSLChannelContextTests extends ESTestCase {
outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {}));
when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
socketConfig = new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class), false);
context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.setSelectionKey(mock(SelectionKey.class));
when(selector.isOnCurrentThread()).thenReturn(true);
@ -180,7 +184,7 @@ public class SSLChannelContextTests extends ESTestCase {
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
when(channel.isOpen()).thenReturn(true);
context.closeFromSelector();
@ -332,8 +336,10 @@ public class SSLChannelContextTests extends ESTestCase {
try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.setSelectionKey(mock(SelectionKey.class));
context.closeChannel();
ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class);
verify(selector).queueWrite(captor.capture());
@ -373,7 +379,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(selector.rawSelector()).thenReturn(realSelector);
when(channel.getRawChannel()).thenReturn(realSocket);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.channelActive();
verify(sslDriver).init();
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.nio.NioHttpChannel;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
@ -77,7 +78,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(false));
@ -99,7 +100,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(false));
assertThat(engine.getWantClientAuth(), is(true));
@ -120,7 +121,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(true));
assertThat(engine.getWantClientAuth(), is(false));
@ -141,7 +142,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(false));
assertThat(engine.getWantClientAuth(), is(false));
@ -159,7 +160,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine defaultEngine = SSLEngineUtils.getSSLEngine(channel);
settings = Settings.builder()
@ -173,7 +174,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
new NetworkService(Collections.emptyList()), mock(BigArrays.class), mock(PageCacheRecycler.class), mock(ThreadPool.class),
xContentRegistry(), new NullDispatcher(), mock(IPFilter.class), sslService, nioGroupFactory);
factory = transport.channelFactory();
channel = factory.createChannel(mock(NioSelector.class), socketChannel);
channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine customEngine = SSLEngineUtils.getSSLEngine(channel);
assertThat(customEngine.getEnabledProtocols(), arrayContaining("TLSv1.2"));
assertThat(customEngine.getEnabledProtocols(), not(equalTo(defaultEngine.getEnabledProtocols())));