Move nio channel initialization to event loop (#45155)

Currently in the transport-nio work we connect and bind channels on the
a thread before the channel is registered with a selector. Additionally,
it is at this point that we set all the socket options. This commit
moves these operations onto the event-loop after the channel has been
registered with a selector. It attempts to set the socket options for a
non-server channel at registration time. If that fails, it will attempt
to set the options after the channel is connected. This should fix
#41071.
This commit is contained in:
Tim Brooks 2019-08-02 17:31:31 -04:00 committed by GitHub
parent ffbe047c32
commit 984ba82251
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 552 additions and 332 deletions

View File

@ -24,9 +24,9 @@ import java.util.function.Consumer;
public class BytesChannelContext extends SocketChannelContext { public class BytesChannelContext extends SocketChannelContext {
public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
NioChannelHandler handler, InboundChannelBuffer channelBuffer) { Consumer<Exception> exceptionHandler, NioChannelHandler handler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, handler, channelBuffer); super(channel, selector, socketConfig, exceptionHandler, handler, channelBuffer);
} }
@Override @Override

View File

@ -19,58 +19,69 @@
package org.elasticsearch.nio; package org.elasticsearch.nio;
import org.elasticsearch.common.CheckedRunnable;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.channels.ServerSocketChannel; import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.security.AccessController; import java.nio.channels.spi.AbstractSelectableChannel;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.function.Supplier; import java.util.function.Supplier;
public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel, Socket extends NioSocketChannel> { public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel, Socket extends NioSocketChannel> {
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean tcpReuseAddress;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
private final ChannelFactory.RawChannelFactory rawChannelFactory; private final ChannelFactory.RawChannelFactory rawChannelFactory;
/** /**
* This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor. * This will create a {@link ChannelFactory}.
*
* @param rawChannelFactory a factory that will construct the raw socket channels
*/ */
protected ChannelFactory(RawChannelFactory rawChannelFactory) { protected ChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize) {
this(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, new RawChannelFactory());
}
/**
* This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor.
*/
protected ChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize, RawChannelFactory rawChannelFactory) {
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpReuseAddress = tcpReuseAddress;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
this.rawChannelFactory = rawChannelFactory; this.rawChannelFactory = rawChannelFactory;
} }
public Socket openNioChannel(InetSocketAddress remoteAddress, Supplier<NioSelector> supplier) throws IOException { public Socket openNioChannel(InetSocketAddress remoteAddress, Supplier<NioSelector> supplier) throws IOException {
SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); SocketChannel rawChannel = rawChannelFactory.openNioChannel();
setNonBlocking(rawChannel);
NioSelector selector = supplier.get(); NioSelector selector = supplier.get();
Socket channel = internalCreateChannel(selector, rawChannel); Socket channel = internalCreateChannel(selector, rawChannel, createSocketConfig(remoteAddress, false));
scheduleChannel(channel, selector); scheduleChannel(channel, selector);
return channel; return channel;
} }
public Socket acceptNioChannel(ServerChannelContext serverContext, Supplier<NioSelector> supplier) throws IOException { public Socket acceptNioChannel(SocketChannel rawChannel, Supplier<NioSelector> supplier) throws IOException {
SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverContext); setNonBlocking(rawChannel);
// Null is returned if there are no pending sockets to accept NioSelector selector = supplier.get();
if (rawChannel == null) { InetSocketAddress remoteAddress = getRemoteAddress(rawChannel);
return null; Socket channel = internalCreateChannel(selector, rawChannel, createSocketConfig(remoteAddress, true));
} else { scheduleChannel(channel, selector);
NioSelector selector = supplier.get(); return channel;
Socket channel = internalCreateChannel(selector, rawChannel);
scheduleChannel(channel, selector);
return channel;
}
} }
public ServerSocket openNioServerSocketChannel(InetSocketAddress address, Supplier<NioSelector> supplier) throws IOException { public ServerSocket openNioServerSocketChannel(InetSocketAddress localAddress, Supplier<NioSelector> supplier) throws IOException {
ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel();
setNonBlocking(rawChannel);
NioSelector selector = supplier.get(); NioSelector selector = supplier.get();
ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel); Config.ServerSocket config = new Config.ServerSocket(tcpReuseAddress, localAddress);
ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel, config);
scheduleServerChannel(serverChannel, selector); scheduleServerChannel(serverChannel, selector);
return serverChannel; return serverChannel;
} }
@ -80,27 +91,38 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
* returned, the channel should be fully created and setup. Read and write contexts and the channel * returned, the channel should be fully created and setup. Read and write contexts and the channel
* exception handler should have been set. * exception handler should have been set.
* *
* @param selector the channel will be registered with * @param selector the channel will be registered with
* @param channel the raw channel * @param channel the raw channel
* @param socketConfig the socket config
* @return the channel * @return the channel
* @throws IOException related to the creation of the channel * @throws IOException related to the creation of the channel
*/ */
public abstract Socket createChannel(NioSelector selector, SocketChannel channel) throws IOException; public abstract Socket createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException;
/** /**
* This method should return a new {@link NioServerSocketChannel} implementation. When this method has * This method should return a new {@link NioServerSocketChannel} implementation. When this method has
* returned, the channel should be fully created and setup. * returned, the channel should be fully created and setup.
* *
* @param selector the channel will be registered with * @param selector the channel will be registered with
* @param channel the raw channel * @param channel the raw channel
* @param socketConfig the socket config
* @return the server channel * @return the server channel
* @throws IOException related to the creation of the channel * @throws IOException related to the creation of the channel
*/ */
public abstract ServerSocket createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException; public abstract ServerSocket createServerChannel(NioSelector selector, ServerSocketChannel channel, Config.ServerSocket socketConfig)
throws IOException;
private Socket internalCreateChannel(NioSelector selector, SocketChannel rawChannel) throws IOException { protected InetSocketAddress getRemoteAddress(SocketChannel rawChannel) throws IOException {
InetSocketAddress remoteAddress = (InetSocketAddress) rawChannel.socket().getRemoteSocketAddress();
if (remoteAddress == null) {
throw new IOException("Accepted socket does not have remote address");
}
return remoteAddress;
}
private Socket internalCreateChannel(NioSelector selector, SocketChannel rawChannel, Config.Socket config) throws IOException {
try { try {
Socket channel = createChannel(selector, rawChannel); Socket channel = createChannel(selector, rawChannel, config);
assert channel.getContext() != null : "channel context should have been set on channel"; assert channel.getContext() != null : "channel context should have been set on channel";
return channel; return channel;
} catch (UncheckedIOException e) { } catch (UncheckedIOException e) {
@ -114,9 +136,10 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
} }
} }
private ServerSocket internalCreateServerChannel(NioSelector selector, ServerSocketChannel rawChannel) throws IOException { private ServerSocket internalCreateServerChannel(NioSelector selector, ServerSocketChannel rawChannel, Config.ServerSocket config)
throws IOException {
try { try {
return createServerChannel(selector, rawChannel); return createServerChannel(selector, rawChannel, config);
} catch (Exception e) { } catch (Exception e) {
closeRawChannel(rawChannel, e); closeRawChannel(rawChannel, e);
throw e; throw e;
@ -141,6 +164,15 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
} }
} }
private void setNonBlocking(AbstractSelectableChannel rawChannel) throws IOException {
try {
rawChannel.configureBlocking(false);
} catch (IOException e) {
closeRawChannel(rawChannel, e);
throw e;
}
}
private static void closeRawChannel(Closeable c, Exception e) { private static void closeRawChannel(Closeable c, Exception e) {
try { try {
c.close(); c.close();
@ -149,107 +181,19 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
} }
} }
private Config.Socket createSocketConfig(InetSocketAddress remoteAddress, boolean isAccepted) {
return new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, remoteAddress,
isAccepted);
}
public static class RawChannelFactory { public static class RawChannelFactory {
private final boolean tcpNoDelay; SocketChannel openNioChannel() throws IOException {
private final boolean tcpKeepAlive; return SocketChannel.open();
private final boolean tcpReusedAddress;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
public RawChannelFactory(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReusedAddress, int tcpSendBufferSize,
int tcpReceiveBufferSize) {
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpReusedAddress = tcpReusedAddress;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
} }
SocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException { ServerSocketChannel openNioServerSocketChannel() throws IOException {
SocketChannel socketChannel = SocketChannel.open(); return ServerSocketChannel.open();
try {
configureSocketChannel(socketChannel);
connect(socketChannel, remoteAddress);
} catch (IOException e) {
closeRawChannel(socketChannel, e);
throw e;
}
return socketChannel;
}
SocketChannel acceptNioChannel(ServerChannelContext serverContext) throws IOException {
ServerSocketChannel rawChannel = serverContext.getChannel().getRawChannel();
assert rawChannel.isBlocking() == false;
SocketChannel socketChannel = accept(rawChannel);
assert rawChannel.isBlocking() == false;
if (socketChannel == null) {
return null;
}
try {
configureSocketChannel(socketChannel);
} catch (IOException e) {
closeRawChannel(socketChannel, e);
throw e;
}
return socketChannel;
}
ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws IOException {
ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.configureBlocking(false);
java.net.ServerSocket socket = serverSocketChannel.socket();
try {
socket.setReuseAddress(tcpReusedAddress);
serverSocketChannel.bind(address);
} catch (IOException e) {
closeRawChannel(serverSocketChannel, e);
throw e;
}
return serverSocketChannel;
}
private static final boolean MAC_OS_X = System.getProperty("os.name").startsWith("Mac OS X");
private static void setSocketOption(CheckedRunnable<SocketException> runnable) throws SocketException {
try {
runnable.run();
} catch (SocketException e) {
if (MAC_OS_X == false) {
// ignore on Mac, see https://github.com/elastic/elasticsearch/issues/41071
throw e;
}
}
}
private void configureSocketChannel(SocketChannel channel) throws IOException {
channel.configureBlocking(false);
java.net.Socket socket = channel.socket();
setSocketOption(() -> socket.setTcpNoDelay(tcpNoDelay));
setSocketOption(() -> socket.setKeepAlive(tcpKeepAlive));
setSocketOption(() -> socket.setReuseAddress(tcpReusedAddress));
if (tcpSendBufferSize > 0) {
setSocketOption(() -> socket.setSendBufferSize(tcpSendBufferSize));
}
if (tcpReceiveBufferSize > 0) {
setSocketOption(() -> socket.setSendBufferSize(tcpReceiveBufferSize));
}
}
public static SocketChannel accept(ServerSocketChannel serverSocketChannel) throws IOException {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<SocketChannel>) serverSocketChannel::accept);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
private static void connect(SocketChannel socketChannel, InetSocketAddress remoteAddress) throws IOException {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Boolean>) () -> socketChannel.connect(remoteAddress));
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
} }
} }
} }

View File

@ -0,0 +1,94 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.nio;
import java.net.InetSocketAddress;
public abstract class Config {
private final boolean tcpReuseAddress;
public Config(boolean tcpReuseAddress) {
this.tcpReuseAddress = tcpReuseAddress;
}
public boolean tcpReuseAddress() {
return tcpReuseAddress;
}
public static class Socket extends Config {
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final int tcpSendBufferSize;
private final int tcpReceiveBufferSize;
private final InetSocketAddress remoteAddress;
private final boolean isAccepted;
public Socket(boolean tcpNoDelay, boolean tcpKeepAlive, boolean tcpReuseAddress, int tcpSendBufferSize, int tcpReceiveBufferSize,
InetSocketAddress remoteAddress, boolean isAccepted) {
super(tcpReuseAddress);
this.tcpNoDelay = tcpNoDelay;
this.tcpKeepAlive = tcpKeepAlive;
this.tcpSendBufferSize = tcpSendBufferSize;
this.tcpReceiveBufferSize = tcpReceiveBufferSize;
this.remoteAddress = remoteAddress;
this.isAccepted = isAccepted;
}
public boolean tcpNoDelay() {
return tcpNoDelay;
}
public boolean tcpKeepAlive() {
return tcpKeepAlive;
}
public int tcpSendBufferSize() {
return tcpSendBufferSize;
}
public int tcpReceiveBufferSize() {
return tcpReceiveBufferSize;
}
public boolean isAccepted() {
return isAccepted;
}
public InetSocketAddress getRemoteAddress() {
return remoteAddress;
}
}
public static class ServerSocket extends Config {
private InetSocketAddress localAddress;
public ServerSocket(boolean tcpReuseAddress, InetSocketAddress localAddress) {
super(tcpReuseAddress);
this.localAddress = localAddress;
}
public InetSocketAddress getLocalAddress() {
return localAddress;
}
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.nio;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel; import java.nio.channels.ServerSocketChannel;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
public class NioServerSocketChannel extends NioChannel { public class NioServerSocketChannel extends NioChannel {
@ -32,7 +33,6 @@ public class NioServerSocketChannel extends NioChannel {
public NioServerSocketChannel(ServerSocketChannel serverSocketChannel) { public NioServerSocketChannel(ServerSocketChannel serverSocketChannel) {
this.serverSocketChannel = serverSocketChannel; this.serverSocketChannel = serverSocketChannel;
attemptToSetLocalAddress();
} }
/** /**
@ -49,6 +49,10 @@ public class NioServerSocketChannel extends NioChannel {
} }
} }
public void addBindListener(BiConsumer<Void, Exception> listener) {
context.addBindListener(listener);
}
@Override @Override
public InetSocketAddress getLocalAddress() { public InetSocketAddress getLocalAddress() {
attemptToSetLocalAddress(); attemptToSetLocalAddress();

View File

@ -19,8 +19,6 @@
package org.elasticsearch.nio; package org.elasticsearch.nio;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -30,17 +28,12 @@ public class NioSocketChannel extends NioChannel {
private final AtomicBoolean contextSet = new AtomicBoolean(false); private final AtomicBoolean contextSet = new AtomicBoolean(false);
private final SocketChannel socketChannel; private final SocketChannel socketChannel;
private final InetSocketAddress remoteAddress; private volatile InetSocketAddress remoteAddress;
private volatile InetSocketAddress localAddress; private volatile InetSocketAddress localAddress;
private SocketChannelContext context; private volatile SocketChannelContext context;
public NioSocketChannel(SocketChannel socketChannel) { public NioSocketChannel(SocketChannel socketChannel) {
this.socketChannel = socketChannel; this.socketChannel = socketChannel;
try {
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
} }
public void setContext(SocketChannelContext context) { public void setContext(SocketChannelContext context) {
@ -70,6 +63,9 @@ public class NioSocketChannel extends NioChannel {
} }
public InetSocketAddress getRemoteAddress() { public InetSocketAddress getRemoteAddress() {
if (remoteAddress == null) {
remoteAddress = (InetSocketAddress) socketChannel.socket().getRemoteSocketAddress();
}
return remoteAddress; return remoteAddress;
} }
@ -81,7 +77,7 @@ public class NioSocketChannel extends NioChannel {
public String toString() { public String toString() {
return "NioSocketChannel{" + return "NioSocketChannel{" +
"localAddress=" + getLocalAddress() + "localAddress=" + getLocalAddress() +
", remoteAddress=" + remoteAddress + ", remoteAddress=" + getRemoteAddress() +
'}'; '}';
} }
} }

View File

@ -19,9 +19,18 @@
package org.elasticsearch.nio; package org.elasticsearch.nio;
import org.elasticsearch.common.concurrent.CompletableContext;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.channels.ServerSocketChannel; import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -29,23 +38,49 @@ public class ServerChannelContext extends ChannelContext<ServerSocketChannel> {
private final NioServerSocketChannel channel; private final NioServerSocketChannel channel;
private final NioSelector selector; private final NioSelector selector;
private final Config.ServerSocket config;
private final Consumer<NioSocketChannel> acceptor; private final Consumer<NioSocketChannel> acceptor;
private final AtomicBoolean isClosing = new AtomicBoolean(false); private final AtomicBoolean isClosing = new AtomicBoolean(false);
private final ChannelFactory<?, ?> channelFactory; private final ChannelFactory<?, ?> channelFactory;
private final CompletableContext<Void> bindContext = new CompletableContext<>();
public ServerChannelContext(NioServerSocketChannel channel, ChannelFactory<?, ?> channelFactory, NioSelector selector, public ServerChannelContext(NioServerSocketChannel channel, ChannelFactory<?, ?> channelFactory, NioSelector selector,
Consumer<NioSocketChannel> acceptor, Consumer<Exception> exceptionHandler) { Config.ServerSocket config, Consumer<NioSocketChannel> acceptor,
Consumer<Exception> exceptionHandler) {
super(channel.getRawChannel(), exceptionHandler); super(channel.getRawChannel(), exceptionHandler);
this.channel = channel; this.channel = channel;
this.channelFactory = channelFactory; this.channelFactory = channelFactory;
this.selector = selector; this.selector = selector;
this.config = config;
this.acceptor = acceptor; this.acceptor = acceptor;
} }
public void acceptChannels(Supplier<NioSelector> selectorSupplier) throws IOException { public void acceptChannels(Supplier<NioSelector> selectorSupplier) throws IOException {
NioSocketChannel acceptedChannel; SocketChannel acceptedChannel;
while ((acceptedChannel = channelFactory.acceptNioChannel(this, selectorSupplier)) != null) { while ((acceptedChannel = accept(rawChannel)) != null) {
acceptor.accept(acceptedChannel); NioSocketChannel nioChannel = channelFactory.acceptNioChannel(acceptedChannel, selectorSupplier);
acceptor.accept(nioChannel);
}
}
public void addBindListener(BiConsumer<Void, Exception> listener) {
bindContext.addListener(listener);
}
@Override
protected void register() throws IOException {
super.register();
configureSocket(rawChannel.socket());
InetSocketAddress localAddress = config.getLocalAddress();
try {
rawChannel.bind(localAddress);
bindContext.complete(null);
} catch (IOException e) {
IOException exception = new IOException("Failed to bind server socket channel {localAddress=" + localAddress + "}.", e);
bindContext.completeExceptionally(exception);
throw exception;
} }
} }
@ -66,4 +101,18 @@ public class ServerChannelContext extends ChannelContext<ServerSocketChannel> {
return channel; return channel;
} }
private void configureSocket(ServerSocket socket) throws IOException {
socket.setReuseAddress(config.tcpReuseAddress());
}
protected static SocketChannel accept(ServerSocketChannel serverSocketChannel) throws IOException {
try {
assert serverSocketChannel.isBlocking() == false;
SocketChannel channel = AccessController.doPrivileged((PrivilegedExceptionAction<SocketChannel>) serverSocketChannel::accept);
assert serverSocketChannel.isBlocking() == false;
return channel;
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
} }

View File

@ -24,9 +24,14 @@ import org.elasticsearch.nio.utils.ByteBufferUtils;
import org.elasticsearch.nio.utils.ExceptionsHelper; import org.elasticsearch.nio.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -47,19 +52,23 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
protected final NioSocketChannel channel; protected final NioSocketChannel channel;
protected final InboundChannelBuffer channelBuffer; protected final InboundChannelBuffer channelBuffer;
protected final AtomicBoolean isClosing = new AtomicBoolean(false); protected final AtomicBoolean isClosing = new AtomicBoolean(false);
private final NioChannelHandler readWriteHandler; private final NioChannelHandler channelHandler;
private final NioSelector selector; private final NioSelector selector;
private final Config.Socket socketConfig;
private final CompletableContext<Void> connectContext = new CompletableContext<>(); private final CompletableContext<Void> connectContext = new CompletableContext<>();
private final LinkedList<FlushOperation> pendingFlushes = new LinkedList<>(); private final LinkedList<FlushOperation> pendingFlushes = new LinkedList<>();
private boolean closeNow; private boolean closeNow;
private boolean socketOptionsSet;
private Exception connectException; private Exception connectException;
protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) { Consumer<Exception> exceptionHandler, NioChannelHandler channelHandler,
InboundChannelBuffer channelBuffer) {
super(channel.getRawChannel(), exceptionHandler); super(channel.getRawChannel(), exceptionHandler);
this.selector = selector; this.selector = selector;
this.channel = channel; this.channel = channel;
this.readWriteHandler = readWriteHandler; this.socketConfig = socketConfig;
this.channelHandler = channelHandler;
this.channelBuffer = channelBuffer; this.channelBuffer = channelBuffer;
} }
@ -73,6 +82,22 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
return channel; return channel;
} }
@Override
protected void register() throws IOException {
super.register();
configureSocket(rawChannel.socket(), false);
if (socketConfig.isAccepted() == false) {
InetSocketAddress remoteAddress = socketConfig.getRemoteAddress();
try {
connect(rawChannel, remoteAddress);
} catch (IOException e) {
throw new IOException("Failed to initiate socket channel connection {remoteAddress=" + remoteAddress + "}.", e);
}
}
}
public void addConnectListener(BiConsumer<Void, Exception> listener) { public void addConnectListener(BiConsumer<Void, Exception> listener) {
connectContext.addListener(listener); connectContext.addListener(listener);
} }
@ -117,6 +142,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
} }
if (isConnected) { if (isConnected) {
connectContext.complete(null); connectContext.complete(null);
configureSocket(rawChannel.socket(), true);
} }
return isConnected; return isConnected;
} }
@ -127,14 +153,14 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
return; return;
} }
WriteOperation writeOperation = readWriteHandler.createWriteOperation(this, message, listener); WriteOperation writeOperation = channelHandler.createWriteOperation(this, message, listener);
getSelector().queueWrite(writeOperation); getSelector().queueWrite(writeOperation);
} }
public void queueWriteOperation(WriteOperation writeOperation) { public void queueWriteOperation(WriteOperation writeOperation) {
getSelector().assertOnSelectorThread(); getSelector().assertOnSelectorThread();
pendingFlushes.addAll(readWriteHandler.writeToBytes(writeOperation)); pendingFlushes.addAll(channelHandler.writeToBytes(writeOperation));
} }
public abstract int read() throws IOException; public abstract int read() throws IOException;
@ -157,7 +183,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
@Override @Override
protected void channelActive() throws IOException { protected void channelActive() throws IOException {
readWriteHandler.channelActive(); channelHandler.channelActive();
} }
@Override @Override
@ -174,14 +200,14 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
isClosing.set(true); isClosing.set(true);
// Poll for new flush operations to close // Poll for new flush operations to close
pendingFlushes.addAll(readWriteHandler.pollFlushOperations()); pendingFlushes.addAll(channelHandler.pollFlushOperations());
FlushOperation flushOperation; FlushOperation flushOperation;
while ((flushOperation = pendingFlushes.pollFirst()) != null) { while ((flushOperation = pendingFlushes.pollFirst()) != null) {
selector.executeFailedListener(flushOperation.getListener(), new ClosedChannelException()); selector.executeFailedListener(flushOperation.getListener(), new ClosedChannelException());
} }
try { try {
readWriteHandler.close(); channelHandler.close();
} catch (IOException e) { } catch (IOException e) {
closingExceptions.add(e); closingExceptions.add(e);
} }
@ -196,12 +222,12 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
protected void handleReadBytes() throws IOException { protected void handleReadBytes() throws IOException {
int bytesConsumed = Integer.MAX_VALUE; int bytesConsumed = Integer.MAX_VALUE;
while (isOpen() && bytesConsumed > 0 && channelBuffer.getIndex() > 0) { while (isOpen() && bytesConsumed > 0 && channelBuffer.getIndex() > 0) {
bytesConsumed = readWriteHandler.consumeReads(channelBuffer); bytesConsumed = channelHandler.consumeReads(channelBuffer);
channelBuffer.release(bytesConsumed); channelBuffer.release(bytesConsumed);
} }
// Some protocols might produce messages to flush during a read operation. // Some protocols might produce messages to flush during a read operation.
pendingFlushes.addAll(readWriteHandler.pollFlushOperations()); pendingFlushes.addAll(channelHandler.pollFlushOperations());
} }
public boolean readyForFlush() { public boolean readyForFlush() {
@ -217,7 +243,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
public abstract boolean selectorShouldClose(); public abstract boolean selectorShouldClose();
protected boolean closeNow() { protected boolean closeNow() {
return closeNow || readWriteHandler.closeNow(); return closeNow || channelHandler.closeNow();
} }
protected void setCloseNow() { protected void setCloseNow() {
@ -288,4 +314,42 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
} }
return totalBytesFlushed; return totalBytesFlushed;
} }
private void configureSocket(Socket socket, boolean isConnectComplete) throws IOException {
if (socketOptionsSet) {
return;
}
try {
// Set reuse address first as it must be set before a bind call. Some implementations throw
// exceptions on other socket options if the channel is not connected. But setting reuse first,
// we ensure that it is properly set before any bind attempt.
socket.setReuseAddress(socketConfig.tcpReuseAddress());
socket.setKeepAlive(socketConfig.tcpKeepAlive());
socket.setTcpNoDelay(socketConfig.tcpNoDelay());
int tcpSendBufferSize = socketConfig.tcpSendBufferSize();
if (tcpSendBufferSize > 0) {
socket.setSendBufferSize(tcpSendBufferSize);
}
int tcpReceiveBufferSize = socketConfig.tcpReceiveBufferSize();
if (tcpReceiveBufferSize > 0) {
socket.setReceiveBufferSize(tcpReceiveBufferSize);
}
socketOptionsSet = true;
} catch (IOException e) {
if (isConnectComplete) {
throw e;
}
// Ignore if not connect complete. Some implementations fail on setting socket options if the
// socket is not connected. We will try again after connection.
}
}
private static void connect(SocketChannel socketChannel, InetSocketAddress remoteAddress) throws IOException {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Boolean>) () -> socketChannel.connect(remoteAddress));
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}
} }

View File

@ -62,7 +62,7 @@ public class BytesChannelContextTests extends ESTestCase {
channelBuffer = InboundChannelBuffer.allocatingInstance(); channelBuffer = InboundChannelBuffer.allocatingInstance();
TestReadWriteHandler handler = new TestReadWriteHandler(readConsumer); TestReadWriteHandler handler = new TestReadWriteHandler(readConsumer);
when(channel.getRawChannel()).thenReturn(rawChannel); when(channel.getRawChannel()).thenReturn(rawChannel);
context = new BytesChannelContext(channel, selector, mock(Consumer.class), handler, channelBuffer); context = new BytesChannelContext(channel, selector, mock(Config.Socket.class), mock(Consumer.class), handler, channelBuffer);
when(selector.isOnCurrentThread()).thenReturn(true); when(selector.isOnCurrentThread()).thenReturn(true);
ByteBuffer buffer = ByteBuffer.allocate(1 << 14); ByteBuffer buffer = ByteBuffer.allocate(1 << 14);

View File

@ -31,7 +31,6 @@ import java.nio.channels.SocketChannel;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -69,10 +68,7 @@ public class ChannelFactoryTests extends ESTestCase {
} }
public void testAcceptChannel() throws IOException { public void testAcceptChannel() throws IOException {
ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); NioSocketChannel channel = channelFactory.acceptNioChannel(rawChannel, socketSelectorSupplier);
when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel);
NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier);
verify(socketSelector).scheduleForRegistration(channel); verify(socketSelector).scheduleForRegistration(channel);
@ -80,18 +76,16 @@ public class ChannelFactoryTests extends ESTestCase {
} }
public void testAcceptedChannelRejected() throws IOException { public void testAcceptedChannelRejected() throws IOException {
ServerChannelContext serverChannelContext = mock(ServerChannelContext.class);
when(rawChannelFactory.acceptNioChannel(serverChannelContext)).thenReturn(rawChannel);
doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any());
expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannelContext, socketSelectorSupplier)); expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(rawChannel, socketSelectorSupplier));
assertFalse(rawChannel.isOpen()); assertFalse(rawChannel.isOpen());
} }
public void testOpenChannel() throws IOException { public void testOpenChannel() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class); InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); when(rawChannelFactory.openNioChannel()).thenReturn(rawChannel);
NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelectorSupplier); NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelectorSupplier);
@ -102,7 +96,7 @@ public class ChannelFactoryTests extends ESTestCase {
public void testOpenedChannelRejected() throws IOException { public void testOpenedChannelRejected() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class); InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); when(rawChannelFactory.openNioChannel()).thenReturn(rawChannel);
doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any());
expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelectorSupplier)); expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelectorSupplier));
@ -112,7 +106,7 @@ public class ChannelFactoryTests extends ESTestCase {
public void testOpenServerChannel() throws IOException { public void testOpenServerChannel() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class); InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); when(rawChannelFactory.openNioServerSocketChannel()).thenReturn(rawServerChannel);
NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier); NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier);
@ -123,7 +117,7 @@ public class ChannelFactoryTests extends ESTestCase {
public void testOpenedServerChannelRejected() throws IOException { public void testOpenedServerChannelRejected() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class); InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel); when(rawChannelFactory.openNioServerSocketChannel()).thenReturn(rawServerChannel);
doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any()); doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any());
expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier)); expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelectorSupplier));
@ -134,19 +128,26 @@ public class ChannelFactoryTests extends ESTestCase {
private static class TestChannelFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> { private static class TestChannelFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> {
TestChannelFactory(RawChannelFactory rawChannelFactory) { TestChannelFactory(RawChannelFactory rawChannelFactory) {
super(rawChannelFactory); super(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, rawChannelFactory);
} }
@Override @Override
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioSocketChannel nioSocketChannel = new NioSocketChannel(channel); NioSocketChannel nioSocketChannel = new NioSocketChannel(channel);
nioSocketChannel.setContext(mock(SocketChannelContext.class)); nioSocketChannel.setContext(mock(SocketChannelContext.class));
return nioSocketChannel; return nioSocketChannel;
} }
@Override @Override
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
return new NioServerSocketChannel(channel); return new NioServerSocketChannel(channel);
} }
@Override
protected InetSocketAddress getRemoteAddress(SocketChannel rawChannel) throws IOException {
// Override this functionality to avoid having to connect the accepted channel
return mock(InetSocketAddress.class);
}
} }
} }

View File

@ -23,7 +23,9 @@ import org.elasticsearch.test.ESTestCase;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket;
import java.nio.channels.CancelledKeyException; import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel; import java.nio.channels.ServerSocketChannel;
@ -33,7 +35,6 @@ import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -67,6 +68,7 @@ public class EventHandlerTests extends ESTestCase {
SocketChannel rawChannel = mock(SocketChannel.class); SocketChannel rawChannel = mock(SocketChannel.class);
when(rawChannel.finishConnect()).thenReturn(true); when(rawChannel.finishConnect()).thenReturn(true);
NioSocketChannel channel = new NioSocketChannel(rawChannel); NioSocketChannel channel = new NioSocketChannel(rawChannel);
when(rawChannel.socket()).thenReturn(mock(Socket.class));
context = new DoNotRegisterSocketContext(channel, selector, channelExceptionHandler, readWriteHandler); context = new DoNotRegisterSocketContext(channel, selector, channelExceptionHandler, readWriteHandler);
channel.setContext(context); channel.setContext(context);
handler.handleRegistration(context); handler.handleRegistration(context);
@ -104,22 +106,8 @@ public class EventHandlerTests extends ESTestCase {
assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps()); assertEquals(SelectionKey.OP_ACCEPT, serverContext.getSelectionKey().interestOps());
} }
public void testHandleAcceptCallsChannelFactory() throws IOException { public void testHandleAcceptAccept() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class));
NioSocketChannel nullChannel = null;
when(channelFactory.acceptNioChannel(same(serverContext), same(selectorSupplier))).thenReturn(childChannel, nullChannel);
handler.acceptChannel(serverContext);
verify(channelFactory, times(2)).acceptNioChannel(same(serverContext), same(selectorSupplier));
}
public void testHandleAcceptCallsServerAcceptCallback() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class));
SocketChannelContext childContext = mock(SocketChannelContext.class);
childChannel.setContext(childContext);
ServerChannelContext serverChannelContext = mock(ServerChannelContext.class); ServerChannelContext serverChannelContext = mock(ServerChannelContext.class);
when(channelFactory.acceptNioChannel(same(serverContext), same(selectorSupplier))).thenReturn(childChannel);
handler.acceptChannel(serverChannelContext); handler.acceptChannel(serverChannelContext);
@ -254,7 +242,7 @@ public class EventHandlerTests extends ESTestCase {
DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler handler) { NioChannelHandler handler) {
super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance()); super(channel, selector, getSocketConfig(), exceptionHandler, handler, InboundChannelBuffer.allocatingInstance());
} }
@Override @Override
@ -270,7 +258,7 @@ public class EventHandlerTests extends ESTestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
DoNotRegisterServerContext(NioServerSocketChannel channel, NioSelector selector, Consumer<NioSocketChannel> acceptor) { DoNotRegisterServerContext(NioServerSocketChannel channel, NioSelector selector, Consumer<NioSocketChannel> acceptor) {
super(channel, channelFactory, selector, acceptor, mock(Consumer.class)); super(channel, channelFactory, selector, getServerSocketConfig(), acceptor, mock(Consumer.class));
} }
@Override @Override
@ -280,4 +268,13 @@ public class EventHandlerTests extends ESTestCase {
selectionKey.attach(this); selectionKey.attach(this);
} }
} }
private static Config.ServerSocket getServerSocketConfig() {
return new Config.ServerSocket(randomBoolean(), mock(InetSocketAddress.class));
}
private static Config.Socket getSocketConfig() {
return new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class),
randomBoolean());
}
} }

View File

@ -19,12 +19,16 @@
package org.elasticsearch.nio; package org.elasticsearch.nio;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
@ -40,11 +44,13 @@ import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.isNull; import static org.mockito.Matchers.isNull;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@SuppressForbidden(reason = "allow call to socket connect")
public class SocketChannelContextTests extends ESTestCase { public class SocketChannelContextTests extends ESTestCase {
private SocketChannel rawChannel; private SocketChannel rawChannel;
@ -55,6 +61,7 @@ public class SocketChannelContextTests extends ESTestCase {
private NioSelector selector; private NioSelector selector;
private NioChannelHandler handler; private NioChannelHandler handler;
private ByteBuffer ioBuffer = ByteBuffer.allocate(1024); private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
private Socket rawSocket;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Before @Before
@ -76,6 +83,8 @@ public class SocketChannelContextTests extends ESTestCase {
ioBuffer.clear(); ioBuffer.clear();
return ioBuffer; return ioBuffer;
}); });
rawSocket = mock(Socket.class);
when(rawChannel.socket()).thenReturn(rawSocket);
} }
public void testIOExceptionSetIfEncountered() throws IOException { public void testIOExceptionSetIfEncountered() throws IOException {
@ -101,6 +110,31 @@ public class SocketChannelContextTests extends ESTestCase {
assertTrue(context.closeNow()); assertTrue(context.closeNow());
} }
public void testRegisterInitiatesConnect() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
boolean isAccepted = randomBoolean();
Config.Socket config;
boolean tcpNoDelay = randomBoolean();
boolean tcpKeepAlive = randomBoolean();
boolean tcpReuseAddress = randomBoolean();
int tcpSendBufferSize = randomIntBetween(1000, 2000);
int tcpReceiveBufferSize = randomIntBetween(1000, 2000);
config = new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, address, isAccepted);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer, config);
context.register();
if (isAccepted) {
verify(rawChannel, times(0)).connect(any(InetSocketAddress.class));
} else {
verify(rawChannel).connect(same(address));
}
verify(rawSocket).setTcpNoDelay(tcpNoDelay);
verify(rawSocket).setKeepAlive(tcpKeepAlive);
verify(rawSocket).setReuseAddress(tcpReuseAddress);
verify(rawSocket).setSendBufferSize(tcpSendBufferSize);
verify(rawSocket).setReceiveBufferSize(tcpReceiveBufferSize);
}
public void testConnectSucceeds() throws IOException { public void testConnectSucceeds() throws IOException {
AtomicBoolean listenerCalled = new AtomicBoolean(false); AtomicBoolean listenerCalled = new AtomicBoolean(false);
when(rawChannel.finishConnect()).thenReturn(false, true); when(rawChannel.finishConnect()).thenReturn(false, true);
@ -142,6 +176,29 @@ public class SocketChannelContextTests extends ESTestCase {
assertSame(ioException, exception.get()); assertSame(ioException, exception.get());
} }
public void testConnectCanSetSocketOptions() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
Config.Socket config;
boolean tcpNoDelay = randomBoolean();
boolean tcpKeepAlive = randomBoolean();
boolean tcpReuseAddress = randomBoolean();
int tcpSendBufferSize = randomIntBetween(1000, 2000);
int tcpReceiveBufferSize = randomIntBetween(1000, 2000);
config = new Config.Socket(tcpNoDelay, tcpKeepAlive, tcpReuseAddress, tcpSendBufferSize, tcpReceiveBufferSize, address, false);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, handler, buffer, config);
doThrow(new SocketException()).doNothing().when(rawSocket).setReuseAddress(tcpReuseAddress);
context.register();
when(rawChannel.finishConnect()).thenReturn(true);
context.connect();
verify(rawSocket, times(2)).setReuseAddress(tcpReuseAddress);
verify(rawSocket).setKeepAlive(tcpKeepAlive);
verify(rawSocket).setTcpNoDelay(tcpNoDelay);
verify(rawSocket).setSendBufferSize(tcpSendBufferSize);
verify(rawSocket).setReceiveBufferSize(tcpReceiveBufferSize);
}
public void testChannelActiveCallsHandler() throws IOException { public void testChannelActiveCallsHandler() throws IOException {
context.channelActive(); context.channelActive();
verify(handler).channelActive(); verify(handler).channelActive();
@ -262,7 +319,8 @@ public class SocketChannelContextTests extends ESTestCase {
when(channel.getRawChannel()).thenReturn(realChannel); when(channel.getRawChannel()).thenReturn(realChannel);
when(channel.isOpen()).thenReturn(true); when(channel.isOpen()).thenReturn(true);
InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance(); InboundChannelBuffer buffer = InboundChannelBuffer.allocatingInstance();
BytesChannelContext context = new BytesChannelContext(channel, selector, exceptionHandler, handler, buffer); BytesChannelContext context = new BytesChannelContext(channel, selector, mock(Config.Socket.class), exceptionHandler, handler,
buffer);
context.closeFromSelector(); context.closeFromSelector();
verify(handler).close(); verify(handler).close();
} }
@ -379,11 +437,21 @@ public class SocketChannelContextTests extends ESTestCase {
assertEquals(1, flushOperation.getBuffersToWrite()[0].position()); assertEquals(1, flushOperation.getBuffersToWrite()[0].position());
} }
private static Config.Socket getSocketConfig() {
return new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class),
randomBoolean());
}
private static class TestSocketChannelContext extends SocketChannelContext { private static class TestSocketChannelContext extends SocketChannelContext {
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) { NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, getSocketConfig());
}
private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer, Config.Socket config) {
super(channel, selector, config, exceptionHandler, readWriteHandler, channelBuffer);
} }
@Override @Override

View File

@ -29,20 +29,13 @@ import java.net.InetSocketAddress;
public class Netty4TcpServerChannel implements TcpServerChannel { public class Netty4TcpServerChannel implements TcpServerChannel {
private final Channel channel; private final Channel channel;
private final String profile;
private final CompletableContext<Void> closeContext = new CompletableContext<>(); private final CompletableContext<Void> closeContext = new CompletableContext<>();
Netty4TcpServerChannel(Channel channel, String profile) { Netty4TcpServerChannel(Channel channel) {
this.channel = channel; this.channel = channel;
this.profile = profile;
Netty4TcpChannel.addListener(this.channel.closeFuture(), closeContext); Netty4TcpChannel.addListener(this.channel.closeFuture(), closeContext);
} }
@Override
public String getProfile() {
return profile;
}
@Override @Override
public InetSocketAddress getLocalAddress() { public InetSocketAddress getLocalAddress() {
return (InetSocketAddress) channel.localAddress(); return (InetSocketAddress) channel.localAddress();

View File

@ -239,7 +239,7 @@ public class Netty4Transport extends TcpTransport {
@Override @Override
protected Netty4TcpServerChannel bind(String name, InetSocketAddress address) { protected Netty4TcpServerChannel bind(String name, InetSocketAddress address) {
Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel(); Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel();
Netty4TcpServerChannel esChannel = new Netty4TcpServerChannel(channel, name); Netty4TcpServerChannel esChannel = new Netty4TcpServerChannel(channel);
channel.attr(SERVER_CHANNEL_KEY).set(esChannel); channel.attr(SERVER_CHANNEL_KEY).set(esChannel);
return esChannel; return esChannel;
} }

View File

@ -22,6 +22,8 @@ package org.elasticsearch.http.nio;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
@ -33,6 +35,7 @@ import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpServerChannel; import org.elasticsearch.http.HttpServerChannel;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
@ -130,7 +133,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
@Override @Override
protected HttpServerChannel bind(InetSocketAddress socketAddress) throws IOException { protected HttpServerChannel bind(InetSocketAddress socketAddress) throws IOException {
return nioGroup.bindServerChannel(socketAddress, channelFactory); NioHttpServerChannel httpServerChannel = nioGroup.bindServerChannel(socketAddress, channelFactory);
PlainActionFuture<Void> future = PlainActionFuture.newFuture();
httpServerChannel.addBindListener(ActionListener.toBiConsumer(future));
future.actionGet();
return httpServerChannel;
} }
protected ChannelFactory<NioHttpServerChannel, NioHttpChannel> channelFactory() { protected ChannelFactory<NioHttpServerChannel, NioHttpChannel> channelFactory() {
@ -144,27 +151,29 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
private class HttpChannelFactory extends ChannelFactory<NioHttpServerChannel, NioHttpChannel> { private class HttpChannelFactory extends ChannelFactory<NioHttpServerChannel, NioHttpChannel> {
private HttpChannelFactory() { private HttpChannelFactory() {
super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); super(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize);
} }
@Override @Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioHttpChannel httpChannel = new NioHttpChannel(channel); NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, HttpReadWriteHandler handler = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis); handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e); Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline, SocketChannelContext context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler,
new InboundChannelBuffer(pageAllocator)); new InboundChannelBuffer(pageAllocator));
httpChannel.setContext(context); httpChannel.setContext(context);
return httpChannel; return httpChannel;
} }
@Override @Override
public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel); NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(httpServerChannel, e); Consumer<Exception> exceptionHandler = (e) -> onServerException(httpServerChannel, e);
Consumer<NioSocketChannel> acceptor = NioHttpServerTransport.this::acceptChannel; Consumer<NioSocketChannel> acceptor = NioHttpServerTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, acceptor, exceptionHandler); ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, socketConfig, acceptor,
exceptionHandler);
httpServerChannel.setContext(context); httpServerChannel.setContext(context);
return httpServerChannel; return httpServerChannel;
} }

View File

@ -31,22 +31,14 @@ import java.nio.channels.ServerSocketChannel;
*/ */
public class NioTcpServerChannel extends NioServerSocketChannel implements TcpServerChannel { public class NioTcpServerChannel extends NioServerSocketChannel implements TcpServerChannel {
private final String profile; public NioTcpServerChannel(ServerSocketChannel socketChannel) {
public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) {
super(socketChannel); super(socketChannel);
this.profile = profile;
} }
public void close() { public void close() {
getContext().closeChannel(); getContext().closeChannel();
} }
@Override
public String getProfile() {
return profile;
}
@Override @Override
public void addCloseListener(ActionListener<Void> listener) { public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener)); addCloseListener(ActionListener.toBiConsumer(listener));

View File

@ -23,6 +23,8 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
@ -31,6 +33,7 @@ import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
@ -38,6 +41,7 @@ import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportSettings;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
@ -70,7 +74,11 @@ public class NioTransport extends TcpTransport {
@Override @Override
protected NioTcpServerChannel bind(String name, InetSocketAddress address) throws IOException { protected NioTcpServerChannel bind(String name, InetSocketAddress address) throws IOException {
TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name); TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
return nioGroup.bindServerChannel(address, channelFactory); NioTcpServerChannel serverChannel = nioGroup.bindServerChannel(address, channelFactory);
PlainActionFuture<Void> future = PlainActionFuture.newFuture();
serverChannel.addBindListener(ActionListener.toBiConsumer(future));
future.actionGet();
return serverChannel;
} }
@Override @Override
@ -85,7 +93,7 @@ public class NioTransport extends TcpTransport {
try { try {
nioGroup = groupFactory.getTransportGroup(); nioGroup = groupFactory.getTransportGroup();
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); ProfileSettings clientProfileSettings = new ProfileSettings(settings, TransportSettings.DEFAULT_PROFILE);
clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings); clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings);
if (NetworkService.NETWORK_SERVER.get(settings)) { if (NetworkService.NETWORK_SERVER.get(settings)) {
@ -133,8 +141,9 @@ public class NioTransport extends TcpTransport {
protected abstract class TcpChannelFactory extends ChannelFactory<NioTcpServerChannel, NioTcpChannel> { protected abstract class TcpChannelFactory extends ChannelFactory<NioTcpServerChannel, NioTcpChannel> {
protected TcpChannelFactory(RawChannelFactory rawChannelFactory) { protected TcpChannelFactory(ProfileSettings profileSettings) {
super(rawChannelFactory); super(profileSettings.tcpNoDelay, profileSettings.tcpKeepAlive, profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()), Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
} }
} }
@ -144,32 +153,29 @@ public class NioTransport extends TcpTransport {
private final String profileName; private final String profileName;
private TcpChannelFactoryImpl(ProfileSettings profileSettings, boolean isClient) { private TcpChannelFactoryImpl(ProfileSettings profileSettings, boolean isClient) {
super(new RawChannelFactory(profileSettings.tcpNoDelay, super(profileSettings);
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
this.isClient = isClient; this.isClient = isClient;
this.profileName = profileSettings.profileName; this.profileName = profileSettings.profileName;
} }
@Override @Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) { public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this); TcpReadWriteHandler handler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e); Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, BytesChannelContext context = new BytesChannelContext(nioChannel, selector, socketConfig, exceptionHandler, handler,
new InboundChannelBuffer(pageAllocator)); new InboundChannelBuffer(pageAllocator));
nioChannel.setContext(context); nioChannel.setContext(context);
return nioChannel; return nioChannel;
} }
@Override @Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
NioTcpServerChannel nioChannel = new NioTcpServerChannel(profileName, channel); Config.ServerSocket socketConfig) {
NioTcpServerChannel nioChannel = new NioTcpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(nioChannel, e); Consumer<Exception> exceptionHandler = (e) -> onServerException(nioChannel, e);
Consumer<NioSocketChannel> acceptor = NioTransport.this::acceptChannel; Consumer<NioSocketChannel> acceptor = NioTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler); ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler);
nioChannel.setContext(context); nioChannel.setContext(context);
return nioChannel; return nioChannel;
} }

View File

@ -31,8 +31,8 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
@ -42,14 +42,15 @@ import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.EventHandler; import org.elasticsearch.nio.EventHandler;
import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelectorGroup; import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSelectorGroup;
import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
@ -149,7 +150,8 @@ class NioHttpClient implements Closeable {
connectFuture.actionGet(); connectFuture.actionGet();
for (HttpRequest request : requests) { for (HttpRequest request : requests) {
nioSocketChannel.getContext().sendMessage(request, (v, e) -> {}); nioSocketChannel.getContext().sendMessage(request, (v, e) -> {
});
} }
if (latch.await(30L, TimeUnit.SECONDS) == false) { if (latch.await(30L, TimeUnit.SECONDS) == false) {
fail("Failed to get all expected responses."); fail("Failed to get all expected responses.");
@ -177,17 +179,17 @@ class NioHttpClient implements Closeable {
private final Collection<FullHttpResponse> content; private final Collection<FullHttpResponse> content;
private ClientChannelFactory(CountDownLatch latch, Collection<FullHttpResponse> content) { private ClientChannelFactory(CountDownLatch latch, Collection<FullHttpResponse> content) {
super(new RawChannelFactory(NetworkService.TCP_NO_DELAY.get(Settings.EMPTY), super(NetworkService.TCP_NO_DELAY.get(Settings.EMPTY),
NetworkService.TCP_KEEP_ALIVE.get(Settings.EMPTY), NetworkService.TCP_KEEP_ALIVE.get(Settings.EMPTY),
NetworkService.TCP_REUSE_ADDRESS.get(Settings.EMPTY), NetworkService.TCP_REUSE_ADDRESS.get(Settings.EMPTY),
Math.toIntExact(NetworkService.TCP_SEND_BUFFER_SIZE.get(Settings.EMPTY).getBytes()), Math.toIntExact(NetworkService.TCP_SEND_BUFFER_SIZE.get(Settings.EMPTY).getBytes()),
Math.toIntExact(NetworkService.TCP_RECEIVE_BUFFER_SIZE.get(Settings.EMPTY).getBytes()))); Math.toIntExact(NetworkService.TCP_RECEIVE_BUFFER_SIZE.get(Settings.EMPTY).getBytes()));
this.latch = latch; this.latch = latch;
this.content = content; this.content = content;
} }
@Override @Override
public NioSocketChannel createChannel(NioSelector selector, java.nio.channels.SocketChannel channel) throws IOException { public NioSocketChannel createChannel(NioSelector selector, java.nio.channels.SocketChannel channel, Config.Socket socketConfig) {
NioSocketChannel nioSocketChannel = new NioSocketChannel(channel); NioSocketChannel nioSocketChannel = new NioSocketChannel(channel);
HttpClientHandler handler = new HttpClientHandler(nioSocketChannel, latch, content); HttpClientHandler handler = new HttpClientHandler(nioSocketChannel, latch, content);
Consumer<Exception> exceptionHandler = (e) -> { Consumer<Exception> exceptionHandler = (e) -> {
@ -195,14 +197,15 @@ class NioHttpClient implements Closeable {
onException(e); onException(e);
nioSocketChannel.close(); nioSocketChannel.close();
}; };
SocketChannelContext context = new BytesChannelContext(nioSocketChannel, selector, exceptionHandler, handler, SocketChannelContext context = new BytesChannelContext(nioSocketChannel, selector, socketConfig, exceptionHandler, handler,
InboundChannelBuffer.allocatingInstance()); InboundChannelBuffer.allocatingInstance());
nioSocketChannel.setContext(context); nioSocketChannel.setContext(context);
return nioSocketChannel; return nioSocketChannel;
} }
@Override @Override
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
throw new UnsupportedOperationException("Cannot create server channel"); throw new UnsupportedOperationException("Cannot create server channel");
} }
} }

View File

@ -21,6 +21,7 @@ package org.elasticsearch.transport.nio;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioServerSocketChannel;
@ -57,20 +58,21 @@ public class NioGroupFactoryTests extends ESTestCase {
private static class BindingFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> { private static class BindingFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> {
private BindingFactory() { private BindingFactory() {
super(new ChannelFactory.RawChannelFactory(false, false, false, -1, -1)); super(false, false, false, -1, -1);
} }
@Override @Override
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
throw new IOException("boom"); throw new IOException("boom");
} }
@Override @Override
public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioServerSocketChannel nioChannel = new NioServerSocketChannel(channel); NioServerSocketChannel nioChannel = new NioServerSocketChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> {}; Consumer<Exception> exceptionHandler = (e) -> {};
Consumer<NioSocketChannel> acceptor = (c) -> {}; Consumer<NioSocketChannel> acceptor = (c) -> {};
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler); ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler);
nioChannel.setContext(context); nioChannel.setContext(context);
return nioChannel; return nioChannel;
} }

View File

@ -31,11 +31,6 @@ import java.net.InetSocketAddress;
*/ */
public interface TcpServerChannel extends CloseableChannel { public interface TcpServerChannel extends CloseableChannel {
/**
* This returns the profile for this channel.
*/
String getProfile();
/** /**
* Returns the local address for this channel. * Returns the local address for this channel.
* *

View File

@ -25,6 +25,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -37,6 +38,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.BytesWriteHandler; import org.elasticsearch.nio.BytesWriteHandler;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSelectorGroup; import org.elasticsearch.nio.NioSelectorGroup;
@ -91,7 +93,11 @@ public class MockNioTransport extends TcpTransport {
@Override @Override
protected MockServerChannel bind(String name, InetSocketAddress address) throws IOException { protected MockServerChannel bind(String name, InetSocketAddress address) throws IOException {
MockTcpChannelFactory channelFactory = this.profileToChannelFactory.get(name); MockTcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
return nioGroup.bindServerChannel(address, channelFactory); MockServerChannel serverChannel = nioGroup.bindServerChannel(address, channelFactory);
PlainActionFuture<Void> future = PlainActionFuture.newFuture();
serverChannel.addBindListener(ActionListener.toBiConsumer(future));
future.actionGet();
return serverChannel;
} }
@Override @Override
@ -190,17 +196,17 @@ public class MockNioTransport extends TcpTransport {
private final String profileName; private final String profileName;
private MockTcpChannelFactory(boolean isClient, ProfileSettings profileSettings, String profileName) { private MockTcpChannelFactory(boolean isClient, ProfileSettings profileSettings, String profileName) {
super(new RawChannelFactory(profileSettings.tcpNoDelay, super(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.tcpKeepAlive,
profileSettings.reuseAddress, profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()), Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
this.isClient = isClient; this.isClient = isClient;
this.profileName = profileName; this.profileName = profileName;
} }
@Override @Override
public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel); MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
IntFunction<Page> pageSupplier = (length) -> { IntFunction<Page> pageSupplier = (length) -> {
if (length > PageCacheRecycler.BYTE_PAGE_SIZE) { if (length > PageCacheRecycler.BYTE_PAGE_SIZE) {
@ -211,7 +217,7 @@ public class MockNioTransport extends TcpTransport {
} }
}; };
MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this); MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), BytesChannelContext context = new BytesChannelContext(nioChannel, selector, socketConfig, (e) -> exceptionCaught(nioChannel, e),
readWriteHandler, new InboundChannelBuffer(pageSupplier)); readWriteHandler, new InboundChannelBuffer(pageSupplier));
nioChannel.setContext(context); nioChannel.setContext(context);
nioChannel.addConnectListener((v, e) -> { nioChannel.addConnectListener((v, e) -> {
@ -229,18 +235,19 @@ public class MockNioTransport extends TcpTransport {
} }
@Override @Override
public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel, Config.ServerSocket socketConfig) {
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel); MockServerChannel nioServerChannel = new MockServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> logger.error(() -> Consumer<Exception> exceptionHandler = (e) -> logger.error(() ->
new ParameterizedMessage("exception from server channel caught on transport layer [{}]", channel), e); new ParameterizedMessage("exception from server channel caught on transport layer [{}]", channel), e);
ServerChannelContext context = new ServerChannelContext(nioServerChannel, null, selector, null, ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, socketConfig,
exceptionHandler) { MockNioTransport.this::acceptChannel, exceptionHandler) {
@Override @Override
public void acceptChannels(Supplier<NioSelector> selectorSupplier) throws IOException { public void acceptChannels(Supplier<NioSelector> selectorSupplier) throws IOException {
int acceptCount = 0; int acceptCount = 0;
NioSocketChannel acceptedChannel; SocketChannel acceptedChannel;
while ((acceptedChannel = MockTcpChannelFactory.this.acceptNioChannel(this, selectorSupplier)) != null) { while ((acceptedChannel = accept(rawChannel)) != null) {
acceptChannel(acceptedChannel); NioSocketChannel nioChannel = MockTcpChannelFactory.this.acceptNioChannel(acceptedChannel, selectorSupplier);
acceptChannel(nioChannel);
++acceptCount; ++acceptCount;
if (acceptCount % 100 == 0) { if (acceptCount % 100 == 0) {
logger.warn("Accepted [{}] connections in a single select loop iteration on [{}]", acceptCount, channel); logger.warn("Accepted [{}] connections in a single select loop iteration on [{}]", acceptCount, channel);
@ -272,11 +279,8 @@ public class MockNioTransport extends TcpTransport {
private static class MockServerChannel extends NioServerSocketChannel implements TcpServerChannel { private static class MockServerChannel extends NioServerSocketChannel implements TcpServerChannel {
private final String profile; MockServerChannel(ServerSocketChannel channel) {
MockServerChannel(String profile, ServerSocketChannel channel) {
super(channel); super(channel);
this.profile = profile;
} }
@Override @Override
@ -284,11 +288,6 @@ public class MockNioTransport extends TcpTransport {
getContext().closeChannel(); getContext().closeChannel();
} }
@Override
public String getProfile() {
return profile;
}
@Override @Override
public void addCloseListener(ActionListener<Void> listener) { public void addCloseListener(ActionListener<Void> listener) {
addCloseListener(ActionListener.toBiConsumer(listener)); addCloseListener(ActionListener.toBiConsumer(listener));

View File

@ -13,6 +13,7 @@ import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.nio.WriteOperation;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
@ -41,15 +42,17 @@ public final class SSLChannelContext extends SocketChannelContext {
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>(); private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver, SSLChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
NioChannelHandler readWriteHandler, InboundChannelBuffer applicationBuffer) { Consumer<Exception> exceptionHandler, SSLDriver sslDriver, NioChannelHandler readWriteHandler,
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(), InboundChannelBuffer applicationBuffer) {
this(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
applicationBuffer); applicationBuffer);
} }
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver, SSLChannelContext(NioSocketChannel channel, NioSelector selector, Config.Socket socketConfig,
NioChannelHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) { Consumer<Exception> exceptionHandler, SSLDriver sslDriver, NioChannelHandler readWriteHandler,
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) {
super(channel, selector, socketConfig, exceptionHandler, readWriteHandler, channelBuffer);
this.sslDriver = sslDriver; this.sslDriver = sslDriver;
this.networkReadBuffer = networkReadBuffer; this.networkReadBuffer = networkReadBuffer;
} }

View File

@ -18,6 +18,7 @@ import org.elasticsearch.http.nio.NioHttpServerChannel;
import org.elasticsearch.http.nio.NioHttpServerTransport; import org.elasticsearch.http.nio.NioHttpServerTransport;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
@ -82,17 +83,17 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
class SecurityHttpChannelFactory extends ChannelFactory<NioHttpServerChannel, NioHttpChannel> { class SecurityHttpChannelFactory extends ChannelFactory<NioHttpServerChannel, NioHttpChannel> {
private SecurityHttpChannelFactory() { private SecurityHttpChannelFactory() {
super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); super(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize);
} }
@Override @Override
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel); NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos); handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
final NioChannelHandler handler; final NioChannelHandler handler;
if (ipFilter != null) { if (ipFilter != null) {
handler = new NioIPFilter(httpHandler, httpChannel.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME); handler = new NioIPFilter(httpHandler, socketConfig.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);
} else { } else {
handler = httpHandler; handler = httpHandler;
} }
@ -113,10 +114,10 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
} }
SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false); SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer, context = new SSLChannelContext(httpChannel, selector, socketConfig, exceptionHandler, sslDriver, handler, networkBuffer,
applicationBuffer); applicationBuffer);
} else { } else {
context = new BytesChannelContext(httpChannel, selector, exceptionHandler, handler, networkBuffer); context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler, networkBuffer);
} }
httpChannel.setContext(context); httpChannel.setContext(context);
@ -124,11 +125,13 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
} }
@Override @Override
public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { public NioHttpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel); NioHttpServerChannel httpServerChannel = new NioHttpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(httpServerChannel, e); Consumer<Exception> exceptionHandler = (e) -> onServerException(httpServerChannel, e);
Consumer<NioSocketChannel> acceptor = SecurityNioHttpServerTransport.this::acceptChannel; Consumer<NioSocketChannel> acceptor = SecurityNioHttpServerTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, acceptor, exceptionHandler); ServerChannelContext context = new ServerChannelContext(httpServerChannel, this, selector, socketConfig, acceptor,
exceptionHandler);
httpServerChannel.setContext(context); httpServerChannel.setContext(context);
return httpServerChannel; return httpServerChannel;

View File

@ -16,7 +16,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioChannelHandler;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
@ -111,9 +111,6 @@ public class SecurityNioTransport extends NioTransport {
@Override @Override
protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) { protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) {
return (node) -> { return (node) -> {
final ChannelFactory.RawChannelFactory rawChannelFactory = new ChannelFactory.RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
SNIHostName serverName; SNIHostName serverName;
String configuredServerName = node.getAttributes().get("server_name"); String configuredServerName = node.getAttributes().get("server_name");
if (configuredServerName != null) { if (configuredServerName != null) {
@ -125,7 +122,7 @@ public class SecurityNioTransport extends NioTransport {
} else { } else {
serverName = null; serverName = null;
} }
return new SecurityClientTcpChannelFactory(rawChannelFactory, serverName); return new SecurityClientTcpChannelFactory(profileSettings, serverName);
}; };
} }
@ -135,26 +132,18 @@ public class SecurityNioTransport extends NioTransport {
private final boolean isClient; private final boolean isClient;
private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) { private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
this(new RawChannelFactory(profileSettings.tcpNoDelay, super(profileSettings);
profileSettings.tcpKeepAlive, this.profileName = profileSettings.profileName;
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), profileSettings.profileName, isClient);
}
private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String profileName, boolean isClient) {
super(rawChannelFactory);
this.profileName = profileName;
this.isClient = isClient; this.isClient = isClient;
} }
@Override @Override
public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
final NioChannelHandler handler; final NioChannelHandler handler;
if (ipFilter != null) { if (ipFilter != null) {
handler = new NioIPFilter(readWriteHandler, nioChannel.getRemoteAddress(), ipFilter, profileName); handler = new NioIPFilter(readWriteHandler, socketConfig.getRemoteAddress(), ipFilter, profileName);
} else { } else {
handler = readWriteHandler; handler = readWriteHandler;
} }
@ -163,12 +152,12 @@ public class SecurityNioTransport extends NioTransport {
SocketChannelContext context; SocketChannelContext context;
if (sslEnabled) { if (sslEnabled) {
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient); SSLDriver sslDriver = new SSLDriver(createSSLEngine(socketConfig), pageAllocator, isClient);
InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer, context = new SSLChannelContext(nioChannel, selector, socketConfig, exceptionHandler, sslDriver, handler, networkBuffer,
applicationBuffer); applicationBuffer);
} else { } else {
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, handler, networkBuffer); context = new BytesChannelContext(nioChannel, selector, socketConfig, exceptionHandler, handler, networkBuffer);
} }
nioChannel.setContext(context); nioChannel.setContext(context);
@ -176,24 +165,25 @@ public class SecurityNioTransport extends NioTransport {
} }
@Override @Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException { public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
NioTcpServerChannel nioChannel = new NioTcpServerChannel(profileName, channel); Config.ServerSocket socketConfig) {
NioTcpServerChannel nioChannel = new NioTcpServerChannel(channel);
Consumer<Exception> exceptionHandler = (e) -> onServerException(nioChannel, e); Consumer<Exception> exceptionHandler = (e) -> onServerException(nioChannel, e);
Consumer<NioSocketChannel> acceptor = SecurityNioTransport.this::acceptChannel; Consumer<NioSocketChannel> acceptor = SecurityNioTransport.this::acceptChannel;
ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, acceptor, exceptionHandler); ServerChannelContext context = new ServerChannelContext(nioChannel, this, selector, socketConfig, acceptor, exceptionHandler);
nioChannel.setContext(context); nioChannel.setContext(context);
return nioChannel; return nioChannel;
} }
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException { protected SSLEngine createSSLEngine(Config.Socket socketConfig) throws IOException {
SSLEngine sslEngine; SSLEngine sslEngine;
SSLConfiguration defaultConfig = profileConfiguration.get(TransportSettings.DEFAULT_PROFILE); SSLConfiguration defaultConfig = profileConfiguration.get(TransportSettings.DEFAULT_PROFILE);
SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig); SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig);
boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled(); boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled();
if (hostnameVerificationEnabled) { if (hostnameVerificationEnabled && socketConfig.isAccepted() == false) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress(); InetSocketAddress remoteAddress = socketConfig.getRemoteAddress();
// we create the socket based on the name given. don't reverse DNS // we create the socket based on the name given. don't reverse DNS
sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort()); sslEngine = sslService.createSSLEngine(sslConfig, remoteAddress.getHostString(), remoteAddress.getPort());
} else { } else {
sslEngine = sslService.createSSLEngine(sslConfig, null, -1); sslEngine = sslService.createSSLEngine(sslConfig, null, -1);
} }
@ -205,19 +195,20 @@ public class SecurityNioTransport extends NioTransport {
private final SNIHostName serverName; private final SNIHostName serverName;
private SecurityClientTcpChannelFactory(RawChannelFactory rawChannelFactory, SNIHostName serverName) { private SecurityClientTcpChannelFactory(ProfileSettings profileSettings, SNIHostName serverName) {
super(rawChannelFactory, TransportSettings.DEFAULT_PROFILE, true); super(profileSettings, true);
this.serverName = serverName; this.serverName = serverName;
} }
@Override @Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel,
Config.ServerSocket socketConfig) {
throw new AssertionError("Cannot create TcpServerChannel with client factory"); throw new AssertionError("Cannot create TcpServerChannel with client factory");
} }
@Override @Override
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException { protected SSLEngine createSSLEngine(Config.Socket socketConfig) throws IOException {
SSLEngine sslEngine = super.createSSLEngine(channel); SSLEngine sslEngine = super.createSSLEngine(socketConfig);
if (serverName != null) { if (serverName != null) {
SSLParameters sslParameters = sslEngine.getSSLParameters(); SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(serverName)); sslParameters.setServerNames(Collections.singletonList(serverName));

View File

@ -14,6 +14,7 @@ import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page; import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.TaskScheduler; import org.elasticsearch.nio.TaskScheduler;
import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
@ -23,6 +24,7 @@ import org.mockito.stubbing.Answer;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.channels.Selector; import java.nio.channels.Selector;
@ -55,6 +57,7 @@ public class SSLChannelContextTests extends ESTestCase {
private Consumer exceptionHandler; private Consumer exceptionHandler;
private SSLDriver sslDriver; private SSLDriver sslDriver;
private int messageLength; private int messageLength;
private Config.Socket socketConfig;
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -73,7 +76,8 @@ public class SSLChannelContextTests extends ESTestCase {
outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {})); outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {}));
when(channel.getRawChannel()).thenReturn(rawChannel); when(channel.getRawChannel()).thenReturn(rawChannel);
exceptionHandler = mock(Consumer.class); exceptionHandler = mock(Consumer.class);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); socketConfig = new Config.Socket(randomBoolean(), randomBoolean(), randomBoolean(), -1, -1, mock(InetSocketAddress.class), false);
context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.setSelectionKey(mock(SelectionKey.class)); context.setSelectionKey(mock(SelectionKey.class));
when(selector.isOnCurrentThread()).thenReturn(true); when(selector.isOnCurrentThread()).thenReturn(true);
@ -180,7 +184,7 @@ public class SSLChannelContextTests extends ESTestCase {
try (SocketChannel realChannel = SocketChannel.open()) { try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel); when(channel.getRawChannel()).thenReturn(realChannel);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
when(channel.isOpen()).thenReturn(true); when(channel.isOpen()).thenReturn(true);
context.closeFromSelector(); context.closeFromSelector();
@ -332,8 +336,10 @@ public class SSLChannelContextTests extends ESTestCase {
try (SocketChannel realChannel = SocketChannel.open()) { try (SocketChannel realChannel = SocketChannel.open()) {
when(channel.getRawChannel()).thenReturn(realChannel); when(channel.getRawChannel()).thenReturn(realChannel);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.setSelectionKey(mock(SelectionKey.class)); context.setSelectionKey(mock(SelectionKey.class));
context.closeChannel(); context.closeChannel();
ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class); ArgumentCaptor<WriteOperation> captor = ArgumentCaptor.forClass(WriteOperation.class);
verify(selector).queueWrite(captor.capture()); verify(selector).queueWrite(captor.capture());
@ -373,7 +379,7 @@ public class SSLChannelContextTests extends ESTestCase {
when(selector.rawSelector()).thenReturn(realSelector); when(selector.rawSelector()).thenReturn(realSelector);
when(channel.getRawChannel()).thenReturn(realSocket); when(channel.getRawChannel()).thenReturn(realSocket);
TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer); TestReadWriteHandler readWriteHandler = new TestReadWriteHandler(readConsumer);
context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); context = new SSLChannelContext(channel, selector, socketConfig, exceptionHandler, sslDriver, readWriteHandler, channelBuffer);
context.channelActive(); context.channelActive();
verify(sslDriver).init(); verify(sslDriver).init();
} }

View File

@ -14,6 +14,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.nio.NioHttpChannel; import org.elasticsearch.http.nio.NioHttpChannel;
import org.elasticsearch.nio.Config;
import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -77,7 +78,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class); SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address); when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(false)); assertThat(engine.getNeedClientAuth(), is(false));
@ -99,7 +100,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class); SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address); when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(false)); assertThat(engine.getNeedClientAuth(), is(false));
assertThat(engine.getWantClientAuth(), is(true)); assertThat(engine.getWantClientAuth(), is(true));
@ -120,7 +121,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class); SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address); when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(true)); assertThat(engine.getNeedClientAuth(), is(true));
assertThat(engine.getWantClientAuth(), is(false)); assertThat(engine.getWantClientAuth(), is(false));
@ -141,7 +142,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class); SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address); when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine engine = SSLEngineUtils.getSSLEngine(channel); SSLEngine engine = SSLEngineUtils.getSSLEngine(channel);
assertThat(engine.getNeedClientAuth(), is(false)); assertThat(engine.getNeedClientAuth(), is(false));
assertThat(engine.getWantClientAuth(), is(false)); assertThat(engine.getWantClientAuth(), is(false));
@ -159,7 +160,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory(); SecurityNioHttpServerTransport.SecurityHttpChannelFactory factory = transport.channelFactory();
SocketChannel socketChannel = mock(SocketChannel.class); SocketChannel socketChannel = mock(SocketChannel.class);
when(socketChannel.getRemoteAddress()).thenReturn(address); when(socketChannel.getRemoteAddress()).thenReturn(address);
NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel); NioHttpChannel channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine defaultEngine = SSLEngineUtils.getSSLEngine(channel); SSLEngine defaultEngine = SSLEngineUtils.getSSLEngine(channel);
settings = Settings.builder() settings = Settings.builder()
@ -173,7 +174,7 @@ public class SecurityNioHttpServerTransportTests extends ESTestCase {
new NetworkService(Collections.emptyList()), mock(BigArrays.class), mock(PageCacheRecycler.class), mock(ThreadPool.class), new NetworkService(Collections.emptyList()), mock(BigArrays.class), mock(PageCacheRecycler.class), mock(ThreadPool.class),
xContentRegistry(), new NullDispatcher(), mock(IPFilter.class), sslService, nioGroupFactory); xContentRegistry(), new NullDispatcher(), mock(IPFilter.class), sslService, nioGroupFactory);
factory = transport.channelFactory(); factory = transport.channelFactory();
channel = factory.createChannel(mock(NioSelector.class), socketChannel); channel = factory.createChannel(mock(NioSelector.class), socketChannel, mock(Config.Socket.class));
SSLEngine customEngine = SSLEngineUtils.getSSLEngine(channel); SSLEngine customEngine = SSLEngineUtils.getSSLEngine(channel);
assertThat(customEngine.getEnabledProtocols(), arrayContaining("TLSv1.2")); assertThat(customEngine.getEnabledProtocols(), arrayContaining("TLSv1.2"));
assertThat(customEngine.getEnabledProtocols(), not(equalTo(defaultEngine.getEnabledProtocols()))); assertThat(customEngine.getEnabledProtocols(), not(equalTo(defaultEngine.getEnabledProtocols())));