Transition transport apis to use void listeners (#27440)

Currently we use ActionListener<TcpChannel> for connect, close, and send
message listeners in TcpTransport. However, all of the listeners have to
capture a reference to a channel in the case of the exception api being
called. This commit changes these listeners to be type <Void> as passing
the channel to onResponse is not necessary. Additionally, this change
makes it easier to integrate with low level transports (which use
different implementations of TcpChannel).
This commit is contained in:
Tim Brooks 2017-11-20 10:47:47 -07:00 committed by GitHub
parent d02f45f694
commit 0a8f48d592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 104 additions and 110 deletions

View File

@ -30,7 +30,6 @@ import org.elasticsearch.common.unit.TimeValue;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -61,7 +60,7 @@ public interface TcpChannel extends Releasable {
* *
* @param listener to be executed * @param listener to be executed
*/ */
void addCloseListener(ActionListener<TcpChannel> listener); void addCloseListener(ActionListener<Void> listener);
/** /**
@ -94,7 +93,7 @@ public interface TcpChannel extends Releasable {
* @param reference to send to channel * @param reference to send to channel
* @param listener to execute upon send completion * @param listener to execute upon send completion
*/ */
void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener); void sendMessage(BytesReference reference, ActionListener<Void> listener);
/** /**
* Closes the channel. * Closes the channel.
@ -114,10 +113,10 @@ public interface TcpChannel extends Releasable {
*/ */
static <C extends TcpChannel> void closeChannels(List<C> channels, boolean blocking) { static <C extends TcpChannel> void closeChannels(List<C> channels, boolean blocking) {
if (blocking) { if (blocking) {
ArrayList<ActionFuture<TcpChannel>> futures = new ArrayList<>(channels.size()); ArrayList<ActionFuture<Void>> futures = new ArrayList<>(channels.size());
for (final C channel : channels) { for (final C channel : channels) {
if (channel.isOpen()) { if (channel.isOpen()) {
PlainActionFuture<TcpChannel> closeFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
channel.addCloseListener(closeFuture); channel.addCloseListener(closeFuture);
channel.close(); channel.close();
futures.add(closeFuture); futures.add(closeFuture);
@ -136,15 +135,14 @@ public interface TcpChannel extends Releasable {
* @param discoveryNode the node for the pending connections * @param discoveryNode the node for the pending connections
* @param connectionFutures representing the pending connections * @param connectionFutures representing the pending connections
* @param connectTimeout to wait for a connection * @param connectTimeout to wait for a connection
* @param <C> the type of channel
* @throws ConnectTransportException if one of the connections fails * @throws ConnectTransportException if one of the connections fails
*/ */
static <C extends TcpChannel> void awaitConnected(DiscoveryNode discoveryNode, List<ActionFuture<C>> connectionFutures, static void awaitConnected(DiscoveryNode discoveryNode, List<ActionFuture<Void>> connectionFutures, TimeValue connectTimeout)
TimeValue connectTimeout) throws ConnectTransportException { throws ConnectTransportException {
Exception connectionException = null; Exception connectionException = null;
boolean allConnected = true; boolean allConnected = true;
for (ActionFuture<C> connectionFuture : connectionFutures) { for (ActionFuture<Void> connectionFuture : connectionFutures) {
try { try {
connectionFuture.get(connectTimeout.getMillis(), TimeUnit.MILLISECONDS); connectionFuture.get(connectTimeout.getMillis(), TimeUnit.MILLISECONDS);
} catch (TimeoutException e) { } catch (TimeoutException e) {
@ -169,8 +167,8 @@ public interface TcpChannel extends Releasable {
} }
} }
static void blockOnFutures(List<ActionFuture<TcpChannel>> futures) { static void blockOnFutures(List<ActionFuture<Void>> futures) {
for (ActionFuture<TcpChannel> future : futures) { for (ActionFuture<Void> future : futures) {
try { try {
future.get(); future.get();
} catch (ExecutionException e) { } catch (ExecutionException e) {

View File

@ -343,7 +343,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
for (TcpChannel channel : channels.getChannels()) { for (TcpChannel channel : channels.getChannels()) {
internalSendMessage(channel, pingHeader, new SendMetricListener(pingHeader.length()) { internalSendMessage(channel, pingHeader, new SendMetricListener(pingHeader.length()) {
@Override @Override
protected void innerInnerOnResponse(TcpChannel channel) { protected void innerInnerOnResponse(Void v) {
successfulPings.inc(); successfulPings.inc();
} }
@ -595,10 +595,10 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
int numConnections = connectionProfile.getNumConnections(); int numConnections = connectionProfile.getNumConnections();
assert numConnections > 0 : "A connection profile must be configured with at least one connection"; assert numConnections > 0 : "A connection profile must be configured with at least one connection";
List<TcpChannel> channels = new ArrayList<>(numConnections); List<TcpChannel> channels = new ArrayList<>(numConnections);
List<ActionFuture<TcpChannel>> connectionFutures = new ArrayList<>(numConnections); List<ActionFuture<Void>> connectionFutures = new ArrayList<>(numConnections);
for (int i = 0; i < numConnections; ++i) { for (int i = 0; i < numConnections; ++i) {
try { try {
PlainActionFuture<TcpChannel> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
connectionFutures.add(connectFuture); connectionFutures.add(connectFuture);
TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture); TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture);
channels.add(channel); channels.add(channel);
@ -940,7 +940,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
for (Map.Entry<String, List<TcpChannel>> entry : serverChannels.entrySet()) { for (Map.Entry<String, List<TcpChannel>> entry : serverChannels.entrySet()) {
String profile = entry.getKey(); String profile = entry.getKey();
List<TcpChannel> channels = entry.getValue(); List<TcpChannel> channels = entry.getValue();
ActionListener<TcpChannel> closeFailLogger = ActionListener.wrap(c -> {}, ActionListener<Void> closeFailLogger = ActionListener.wrap(c -> {},
e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e)); e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e));
channels.forEach(c -> c.addCloseListener(closeFailLogger)); channels.forEach(c -> c.addCloseListener(closeFailLogger));
TcpChannel.closeChannels(channels, true); TcpChannel.closeChannels(channels, true);
@ -1016,7 +1016,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8)); BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8));
final SendMetricListener closeChannel = new SendMetricListener(message.length()) { final SendMetricListener closeChannel = new SendMetricListener(message.length()) {
@Override @Override
protected void innerInnerOnResponse(TcpChannel channel) { protected void innerInnerOnResponse(Void v) {
TcpChannel.closeChannel(channel, false); TcpChannel.closeChannel(channel, false);
} }
@ -1060,7 +1060,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
* @return the pending connection * @return the pending connection
* @throws IOException if an I/O exception occurs while opening the channel * @throws IOException if an I/O exception occurs while opening the channel
*/ */
protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> connectListener) protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException; throws IOException;
/** /**
@ -1686,7 +1686,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
/** /**
* This listener increments the transmitted bytes metric on success. * This listener increments the transmitted bytes metric on success.
*/ */
private abstract class SendMetricListener extends NotifyOnceListener<TcpChannel> { private abstract class SendMetricListener extends NotifyOnceListener<Void> {
private final long messageSize; private final long messageSize;
private SendMetricListener(long messageSize) { private SendMetricListener(long messageSize) {
@ -1694,12 +1694,12 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
} }
@Override @Override
protected final void innerOnResponse(org.elasticsearch.transport.TcpChannel object) { protected final void innerOnResponse(Void object) {
transmittedBytesMetric.inc(messageSize); transmittedBytesMetric.inc(messageSize);
innerInnerOnResponse(object); innerInnerOnResponse(object);
} }
protected abstract void innerInnerOnResponse(org.elasticsearch.transport.TcpChannel object); protected abstract void innerInnerOnResponse(Void object);
} }
private final class SendListener extends SendMetricListener { private final class SendListener extends SendMetricListener {
@ -1715,7 +1715,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
} }
@Override @Override
protected void innerInnerOnResponse(TcpChannel channel) { protected void innerInnerOnResponse(Void v) {
release(); release();
} }

View File

@ -185,8 +185,8 @@ public class TcpTransportTests extends ESTestCase {
} }
@Override @Override
protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
ActionListener<TcpChannel> connectListener) throws IOException { throws IOException {
return new FakeChannel(messageCaptor); return new FakeChannel(messageCaptor);
} }
@ -251,7 +251,7 @@ public class TcpTransportTests extends ESTestCase {
} }
@Override @Override
public void addCloseListener(ActionListener<TcpChannel> listener) { public void addCloseListener(ActionListener<Void> listener) {
} }
@Override @Override
@ -269,7 +269,7 @@ public class TcpTransportTests extends ESTestCase {
} }
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
messageCaptor.set(reference); messageCaptor.set(reference);
} }
} }

View File

@ -249,7 +249,7 @@ public class Netty4Transport extends TcpTransport {
} }
@Override @Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> listener) protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> listener)
throws IOException { throws IOException {
ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address()); ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address());
Channel channel = channelFuture.channel(); Channel channel = channelFuture.channel();
@ -264,7 +264,7 @@ public class Netty4Transport extends TcpTransport {
channelFuture.addListener(f -> { channelFuture.addListener(f -> {
if (f.isSuccess()) { if (f.isSuccess()) {
listener.onResponse(nettyChannel); listener.onResponse(null);
} else { } else {
Throwable cause = f.cause(); Throwable cause = f.cause();
if (cause instanceof Error) { if (cause instanceof Error) {

View File

@ -34,13 +34,13 @@ import java.util.concurrent.CompletableFuture;
public class NettyTcpChannel implements TcpChannel { public class NettyTcpChannel implements TcpChannel {
private final Channel channel; private final Channel channel;
private final CompletableFuture<TcpChannel> closeContext = new CompletableFuture<>(); private final CompletableFuture<Void> closeContext = new CompletableFuture<>();
NettyTcpChannel(Channel channel) { NettyTcpChannel(Channel channel) {
this.channel = channel; this.channel = channel;
this.channel.closeFuture().addListener(f -> { this.channel.closeFuture().addListener(f -> {
if (f.isSuccess()) { if (f.isSuccess()) {
closeContext.complete(this); closeContext.complete(null);
} else { } else {
Throwable cause = f.cause(); Throwable cause = f.cause();
if (cause instanceof Error) { if (cause instanceof Error) {
@ -59,7 +59,7 @@ public class NettyTcpChannel implements TcpChannel {
} }
@Override @Override
public void addCloseListener(ActionListener<TcpChannel> listener) { public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener)); closeContext.whenComplete(ActionListener.toBiConsumer(listener));
} }
@ -79,11 +79,11 @@ public class NettyTcpChannel implements TcpChannel {
} }
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
final ChannelFuture future = channel.writeAndFlush(Netty4Utils.toByteBuf(reference)); final ChannelFuture future = channel.writeAndFlush(Netty4Utils.toByteBuf(reference));
future.addListener(f -> { future.addListener(f -> {
if (f.isSuccess()) { if (f.isSuccess()) {
listener.onResponse(this); listener.onResponse(null);
} else { } else {
final Throwable cause = f.cause(); final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause); Netty4Utils.maybeDie(cause);

View File

@ -171,7 +171,7 @@ public class MockTcpTransport extends TcpTransport {
} }
@Override @Override
protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> connectListener) protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException { throws IOException {
InetSocketAddress address = node.getAddress().address(); InetSocketAddress address = node.getAddress().address();
final MockSocket socket = new MockSocket(); final MockSocket socket = new MockSocket();
@ -186,7 +186,7 @@ public class MockTcpTransport extends TcpTransport {
MockChannel channel = new MockChannel(socket, address, "none", (c) -> {}); MockChannel channel = new MockChannel(socket, address, "none", (c) -> {});
channel.loopRead(executor); channel.loopRead(executor);
success = true; success = true;
connectListener.onResponse(channel); connectListener.onResponse(null);
return channel; return channel;
} finally { } finally {
if (success == false) { if (success == false) {
@ -231,7 +231,7 @@ public class MockTcpTransport extends TcpTransport {
private final String profile; private final String profile;
private final CancellableThreads cancellableThreads = new CancellableThreads(); private final CancellableThreads cancellableThreads = new CancellableThreads();
private final Closeable onClose; private final Closeable onClose;
private final CompletableFuture<TcpChannel> closeFuture = new CompletableFuture<>(); private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();
/** /**
* Constructs a new MockChannel instance intended for handling the actual incoming / outgoing traffic. * Constructs a new MockChannel instance intended for handling the actual incoming / outgoing traffic.
@ -356,14 +356,14 @@ public class MockTcpTransport extends TcpTransport {
public void close() { public void close() {
try { try {
close0(); close0();
closeFuture.complete(this); closeFuture.complete(null);
} catch (IOException e) { } catch (IOException e) {
closeFuture.completeExceptionally(e); closeFuture.completeExceptionally(e);
} }
} }
@Override @Override
public void addCloseListener(ActionListener<TcpChannel> listener) { public void addCloseListener(ActionListener<Void> listener) {
closeFuture.whenComplete(ActionListener.toBiConsumer(listener)); closeFuture.whenComplete(ActionListener.toBiConsumer(listener));
} }
@ -386,14 +386,14 @@ public class MockTcpTransport extends TcpTransport {
} }
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
try { try {
synchronized (this) { synchronized (this) {
OutputStream outputStream = new BufferedOutputStream(activeChannel.getOutputStream()); OutputStream outputStream = new BufferedOutputStream(activeChannel.getOutputStream());
reference.writeTo(outputStream); reference.writeTo(outputStream);
outputStream.flush(); outputStream.flush();
} }
listener.onResponse(this); listener.onResponse(null);
} catch (IOException e) { } catch (IOException e) {
listener.onFailure(e); listener.onFailure(e);
onException(this, e); onException(this, e);

View File

@ -20,7 +20,6 @@
package org.elasticsearch.transport.nio; package org.elasticsearch.transport.nio;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@ -32,7 +31,6 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transports; import org.elasticsearch.transport.Transports;
import org.elasticsearch.transport.nio.channel.ChannelFactory; import org.elasticsearch.transport.nio.channel.ChannelFactory;
@ -95,22 +93,11 @@ public class NioTransport extends TcpTransport {
} }
@Override @Override
protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> connectListener) protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException { throws IOException {
NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get()); NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
openChannels.clientChannelOpened(channel); openChannels.clientChannelOpened(channel);
// TODO: Temporary conversion due to types channel.addConnectListener(connectListener);
channel.addConnectListener(new ActionListener<NioChannel>() {
@Override
public void onResponse(NioChannel nioChannel) {
connectListener.onResponse(nioChannel);
}
@Override
public void onFailure(Exception e) {
connectListener.onFailure(e);
}
});
return channel; return channel;
} }

View File

@ -24,7 +24,6 @@ import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import java.io.IOException; import java.io.IOException;
@ -33,10 +32,10 @@ import java.util.ArrayList;
public class WriteOperation { public class WriteOperation {
private final NioSocketChannel channel; private final NioSocketChannel channel;
private final ActionListener<NioChannel> listener; private final ActionListener<Void> listener;
private final NetworkBytesReference[] references; private final NetworkBytesReference[] references;
public WriteOperation(NioSocketChannel channel, BytesReference bytesReference, ActionListener<NioChannel> listener) { public WriteOperation(NioSocketChannel channel, BytesReference bytesReference, ActionListener<Void> listener) {
this.channel = channel; this.channel = channel;
this.listener = listener; this.listener = listener;
this.references = toArray(bytesReference); this.references = toArray(bytesReference);
@ -46,7 +45,7 @@ public class WriteOperation {
return references; return references;
} }
public ActionListener<NioChannel> getListener() { public ActionListener<Void> getListener() {
return listener; return listener;
} }

View File

@ -20,7 +20,6 @@
package org.elasticsearch.transport.nio.channel; package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.nio.ESSelector; import org.elasticsearch.transport.nio.ESSelector;
import java.io.IOException; import java.io.IOException;
@ -58,7 +57,7 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
private final InetSocketAddress localAddress; private final InetSocketAddress localAddress;
private final String profile; private final String profile;
private final CompletableFuture<TcpChannel> closeContext = new CompletableFuture<>(); private final CompletableFuture<Void> closeContext = new CompletableFuture<>();
private final ESSelector selector; private final ESSelector selector;
private SelectionKey selectionKey; private SelectionKey selectionKey;
@ -111,7 +110,7 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
if (closeContext.isDone() == false) { if (closeContext.isDone() == false) {
try { try {
closeRawChannel(); closeRawChannel();
closeContext.complete(this); closeContext.complete(null);
} catch (IOException e) { } catch (IOException e) {
closeContext.completeExceptionally(e); closeContext.completeExceptionally(e);
throw e; throw e;
@ -156,7 +155,7 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
} }
@Override @Override
public void addCloseListener(ActionListener<TcpChannel> listener) { public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener)); closeContext.whenComplete(ActionListener.toBiConsumer(listener));
} }

View File

@ -21,12 +21,10 @@ package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.AcceptingSelector;
import java.io.IOException; import java.io.IOException;
import java.nio.channels.ServerSocketChannel; import java.nio.channels.ServerSocketChannel;
import java.util.concurrent.Future;
public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChannel> { public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChannel> {
@ -43,7 +41,7 @@ public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChann
} }
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
throw new UnsupportedOperationException("Cannot send a message to a server channel."); throw new UnsupportedOperationException("Cannot send a message to a server channel.");
} }

View File

@ -21,7 +21,6 @@ package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.nio.NetworkBytesReference; import org.elasticsearch.transport.nio.NetworkBytesReference;
import org.elasticsearch.transport.nio.SocketSelector; import org.elasticsearch.transport.nio.SocketSelector;
@ -36,7 +35,7 @@ import java.util.concurrent.CompletableFuture;
public class NioSocketChannel extends AbstractNioChannel<SocketChannel> { public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
private final InetSocketAddress remoteAddress; private final InetSocketAddress remoteAddress;
private final CompletableFuture<NioChannel> connectContext = new CompletableFuture<>(); private final CompletableFuture<Void> connectContext = new CompletableFuture<>();
private final SocketSelector socketSelector; private final SocketSelector socketSelector;
private WriteContext writeContext; private WriteContext writeContext;
private ReadContext readContext; private ReadContext readContext;
@ -49,19 +48,8 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
} }
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
// TODO: Temporary conversion due to types writeContext.sendMessage(reference, listener);
writeContext.sendMessage(reference, new ActionListener<NioChannel>() {
@Override
public void onResponse(NioChannel nioChannel) {
listener.onResponse(nioChannel);
}
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
});
} }
@Override @Override
@ -169,12 +157,12 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
isConnected = internalFinish(); isConnected = internalFinish();
} }
if (isConnected) { if (isConnected) {
connectContext.complete(this); connectContext.complete(null);
} }
return isConnected; return isConnected;
} }
public void addConnectListener(ActionListener<NioChannel> listener) { public void addConnectListener(ActionListener<Void> listener) {
connectContext.whenComplete(ActionListener.toBiConsumer(listener)); connectContext.whenComplete(ActionListener.toBiConsumer(listener));
} }

View File

@ -38,7 +38,7 @@ public class TcpWriteContext implements WriteContext {
} }
@Override @Override
public void sendMessage(BytesReference reference, ActionListener<NioChannel> listener) { public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
if (channel.isWritable() == false) { if (channel.isWritable() == false) {
listener.onFailure(new ClosedChannelException()); listener.onFailure(new ClosedChannelException());
return; return;
@ -96,7 +96,7 @@ public class TcpWriteContext implements WriteContext {
} }
if (headOp.isFullyFlushed()) { if (headOp.isFullyFlushed()) {
headOp.getListener().onResponse(channel); headOp.getListener().onResponse(null);
} else { } else {
queued.push(headOp); queued.push(headOp);
} }

View File

@ -27,7 +27,7 @@ import java.io.IOException;
public interface WriteContext { public interface WriteContext {
void sendMessage(BytesReference reference, ActionListener<NioChannel> listener); void sendMessage(BytesReference reference, ActionListener<Void> listener);
void queueWriteOperations(WriteOperation writeOperation); void queueWriteOperations(WriteOperation writeOperation);

View File

@ -53,7 +53,7 @@ public class SocketSelectorTests extends ESTestCase {
private NioSocketChannel channel; private NioSocketChannel channel;
private TestSelectionKey selectionKey; private TestSelectionKey selectionKey;
private WriteContext writeContext; private WriteContext writeContext;
private ActionListener<NioChannel> listener; private ActionListener<Void> listener;
private NetworkBytesReference bufferReference = NetworkBytesReference.wrap(new BytesArray(new byte[1])); private NetworkBytesReference bufferReference = NetworkBytesReference.wrap(new BytesArray(new byte[1]));
private Selector rawSelector; private Selector rawSelector;

View File

@ -36,7 +36,7 @@ import static org.mockito.Mockito.when;
public class WriteOperationTests extends ESTestCase { public class WriteOperationTests extends ESTestCase {
private NioSocketChannel channel; private NioSocketChannel channel;
private ActionListener<NioChannel> listener; private ActionListener<Void> listener;
@Before @Before
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")

View File

@ -20,6 +20,7 @@
package org.elasticsearch.transport.nio.channel; package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.AcceptingSelector;
@ -61,26 +62,39 @@ public class NioServerSocketChannelTests extends ESTestCase {
} }
public void testClose() throws Exception { public void testClose() throws Exception {
AtomicReference<TcpChannel> ref = new AtomicReference<>(); AtomicBoolean isClosed = new AtomicBoolean(false);
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
NioChannel channel = new DoNotCloseServerChannel("nio", mock(ServerSocketChannel.class), mock(ChannelFactory.class), selector); NioChannel channel = new DoNotCloseServerChannel("nio", mock(ServerSocketChannel.class), mock(ChannelFactory.class), selector);
Consumer<TcpChannel> listener = (c) -> {
ref.set(c); channel.addCloseListener(new ActionListener<Void>() {
@Override
public void onResponse(Void o) {
isClosed.set(true);
latch.countDown(); latch.countDown();
}; }
channel.addCloseListener(ActionListener.wrap(listener::accept, (e) -> listener.accept(channel)));
@Override
public void onFailure(Exception e) {
isClosed.set(true);
latch.countDown();
}
});
assertTrue(channel.isOpen()); assertTrue(channel.isOpen());
assertFalse(closedRawChannel.get()); assertFalse(closedRawChannel.get());
assertFalse(isClosed.get());
TcpChannel.closeChannel(channel, true); PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
channel.addCloseListener(closeFuture);
channel.close();
closeFuture.actionGet();
assertTrue(closedRawChannel.get()); assertTrue(closedRawChannel.get());
assertFalse(channel.isOpen()); assertFalse(channel.isOpen());
latch.await(); latch.await();
assertSame(channel, ref.get()); assertTrue(isClosed.get());
} }
private class DoNotCloseServerChannel extends DoNotRegisterServerChannel { private class DoNotCloseServerChannel extends DoNotRegisterServerChannel {

View File

@ -68,29 +68,40 @@ public class NioSocketChannelTests extends ESTestCase {
} }
public void testClose() throws Exception { public void testClose() throws Exception {
AtomicReference<TcpChannel> ref = new AtomicReference<>(); AtomicBoolean isClosed = new AtomicBoolean(false);
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, mock(SocketChannel.class), selector); NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, mock(SocketChannel.class), selector);
openChannels.clientChannelOpened(socketChannel); openChannels.clientChannelOpened(socketChannel);
socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class));
Consumer<TcpChannel> listener = (c) -> { socketChannel.addCloseListener(new ActionListener<Void>() {
ref.set(c); @Override
public void onResponse(Void o) {
isClosed.set(true);
latch.countDown(); latch.countDown();
}; }
socketChannel.addCloseListener(ActionListener.wrap(listener::accept, (e) -> listener.accept(socketChannel))); @Override
public void onFailure(Exception e) {
isClosed.set(true);
latch.countDown();
}
});
assertTrue(socketChannel.isOpen()); assertTrue(socketChannel.isOpen());
assertFalse(closedRawChannel.get()); assertFalse(closedRawChannel.get());
assertFalse(isClosed.get());
assertTrue(openChannels.getClientChannels().containsKey(socketChannel)); assertTrue(openChannels.getClientChannels().containsKey(socketChannel));
TcpChannel.closeChannel(socketChannel, true); PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
socketChannel.addCloseListener(closeFuture);
socketChannel.close();
closeFuture.actionGet();
assertTrue(closedRawChannel.get()); assertTrue(closedRawChannel.get());
assertFalse(socketChannel.isOpen()); assertFalse(socketChannel.isOpen());
assertFalse(openChannels.getClientChannels().containsKey(socketChannel)); assertFalse(openChannels.getClientChannels().containsKey(socketChannel));
latch.await(); latch.await();
assertSame(socketChannel, ref.get()); assertTrue(isClosed.get());
} }
public void testConnectSucceeds() throws Exception { public void testConnectSucceeds() throws Exception {
@ -100,7 +111,7 @@ public class NioSocketChannelTests extends ESTestCase {
socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class));
selector.scheduleForRegistration(socketChannel); selector.scheduleForRegistration(socketChannel);
PlainActionFuture<NioChannel> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
socketChannel.addConnectListener(connectFuture); socketChannel.addConnectListener(connectFuture);
connectFuture.get(100, TimeUnit.SECONDS); connectFuture.get(100, TimeUnit.SECONDS);
@ -116,7 +127,7 @@ public class NioSocketChannelTests extends ESTestCase {
socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class));
selector.scheduleForRegistration(socketChannel); selector.scheduleForRegistration(socketChannel);
PlainActionFuture<NioChannel> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
socketChannel.addConnectListener(connectFuture); socketChannel.addConnectListener(connectFuture);
ExecutionException e = expectThrows(ExecutionException.class, () -> connectFuture.get(100, TimeUnit.SECONDS)); ExecutionException e = expectThrows(ExecutionException.class, () -> connectFuture.get(100, TimeUnit.SECONDS));
assertTrue(e.getCause() instanceof IOException); assertTrue(e.getCause() instanceof IOException);

View File

@ -40,7 +40,7 @@ import static org.mockito.Mockito.when;
public class TcpWriteContextTests extends ESTestCase { public class TcpWriteContextTests extends ESTestCase {
private SocketSelector selector; private SocketSelector selector;
private ActionListener<NioChannel> listener; private ActionListener<Void> listener;
private TcpWriteContext writeContext; private TcpWriteContext writeContext;
private NioSocketChannel channel; private NioSocketChannel channel;
@ -136,7 +136,7 @@ public class TcpWriteContextTests extends ESTestCase {
writeContext.flushChannel(); writeContext.flushChannel();
verify(writeOperation).flush(); verify(writeOperation).flush();
verify(listener).onResponse(channel); verify(listener).onResponse(null);
assertFalse(writeContext.hasQueuedWriteOps()); assertFalse(writeContext.hasQueuedWriteOps());
} }
@ -151,7 +151,7 @@ public class TcpWriteContextTests extends ESTestCase {
when(writeOperation.isFullyFlushed()).thenReturn(false); when(writeOperation.isFullyFlushed()).thenReturn(false);
writeContext.flushChannel(); writeContext.flushChannel();
verify(listener, times(0)).onResponse(channel); verify(listener, times(0)).onResponse(null);
assertTrue(writeContext.hasQueuedWriteOps()); assertTrue(writeContext.hasQueuedWriteOps());
} }
@ -173,7 +173,7 @@ public class TcpWriteContextTests extends ESTestCase {
when(writeOperation2.isFullyFlushed()).thenReturn(false); when(writeOperation2.isFullyFlushed()).thenReturn(false);
writeContext.flushChannel(); writeContext.flushChannel();
verify(listener).onResponse(channel); verify(listener).onResponse(null);
verify(listener2, times(0)).onResponse(channel); verify(listener2, times(0)).onResponse(channel);
assertTrue(writeContext.hasQueuedWriteOps()); assertTrue(writeContext.hasQueuedWriteOps());
@ -181,7 +181,7 @@ public class TcpWriteContextTests extends ESTestCase {
writeContext.flushChannel(); writeContext.flushChannel();
verify(listener2).onResponse(channel); verify(listener2).onResponse(null);
assertFalse(writeContext.hasQueuedWriteOps()); assertFalse(writeContext.hasQueuedWriteOps());
} }