Remove parameterization from TcpTransport (#27407)

This commit is a follow up to the work completed in #27132. Essentially
it transitions two more methods (sendMessage and getLocalAddress) from
Transport to TcpChannel. With this change, there is no longer a need for
TcpTransport to be aware of the specific type of channel a transport
returns. So that class is no longer parameterized by channel type.
This commit is contained in:
Tim Brooks 2017-11-16 11:19:36 -07:00 committed by GitHub
parent 35a5922927
commit 80ef9bbdb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 239 additions and 206 deletions

View File

@ -19,19 +19,19 @@
package org.elasticsearch.transport; package org.elasticsearch.transport;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -80,6 +80,22 @@ public interface TcpChannel extends Releasable {
*/ */
boolean isOpen(); boolean isOpen();
/**
* Returns the local address for this channel.
*
* @return the local address of this channel.
*/
InetSocketAddress getLocalAddress();
/**
* Sends a tcp message to the channel. The listener will be executed once the send process has been
* completed.
*
* @param reference to send to channel
* @param listener to execute upon send completion
*/
void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener);
/** /**
* Closes the channel. * Closes the channel.
* *

View File

@ -118,7 +118,7 @@ import static org.elasticsearch.common.transport.NetworkExceptionHelper.isCloseC
import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException; import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractLifecycleComponent implements Transport { public abstract class TcpTransport extends AbstractLifecycleComponent implements Transport {
public static final String TRANSPORT_SERVER_WORKER_THREAD_NAME_PREFIX = "transport_server_worker"; public static final String TRANSPORT_SERVER_WORKER_THREAD_NAME_PREFIX = "transport_server_worker";
public static final String TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX = "transport_client_boss"; public static final String TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX = "transport_client_boss";
@ -199,8 +199,8 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
protected final ConcurrentMap<DiscoveryNode, NodeChannels> connectedNodes = newConcurrentMap(); protected final ConcurrentMap<DiscoveryNode, NodeChannels> connectedNodes = newConcurrentMap();
protected final ConcurrentMap<String, BoundTransportAddress> profileBoundAddresses = newConcurrentMap(); protected final ConcurrentMap<String, BoundTransportAddress> profileBoundAddresses = newConcurrentMap();
private final Map<String, List<Channel>> serverChannels = newConcurrentMap(); private final Map<String, List<TcpChannel>> serverChannels = newConcurrentMap();
private final Set<Channel> acceptedChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); private final Set<TcpChannel> acceptedChannels = Collections.newSetFromMap(new ConcurrentHashMap<>());
protected final KeyedLock<String> connectionLock = new KeyedLock<>(); protected final KeyedLock<String> connectionLock = new KeyedLock<>();
private final NamedWriteableRegistry namedWriteableRegistry; private final NamedWriteableRegistry namedWriteableRegistry;
@ -340,10 +340,10 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
for (Map.Entry<DiscoveryNode, NodeChannels> entry : connectedNodes.entrySet()) { for (Map.Entry<DiscoveryNode, NodeChannels> entry : connectedNodes.entrySet()) {
DiscoveryNode node = entry.getKey(); DiscoveryNode node = entry.getKey();
NodeChannels channels = entry.getValue(); NodeChannels channels = entry.getValue();
for (Channel channel : channels.getChannels()) { for (TcpChannel channel : channels.getChannels()) {
internalSendMessage(channel, pingHeader, new SendMetricListener<Channel>(pingHeader.length()) { internalSendMessage(channel, pingHeader, new SendMetricListener(pingHeader.length()) {
@Override @Override
protected void innerInnerOnResponse(Channel channel) { protected void innerInnerOnResponse(TcpChannel channel) {
successfulPings.inc(); successfulPings.inc();
} }
@ -397,12 +397,12 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
public final class NodeChannels implements Connection { public final class NodeChannels implements Connection {
private final Map<TransportRequestOptions.Type, ConnectionProfile.ConnectionTypeHandle> typeMapping; private final Map<TransportRequestOptions.Type, ConnectionProfile.ConnectionTypeHandle> typeMapping;
private final List<Channel> channels; private final List<TcpChannel> channels;
private final DiscoveryNode node; private final DiscoveryNode node;
private final AtomicBoolean closed = new AtomicBoolean(false); private final AtomicBoolean closed = new AtomicBoolean(false);
private final Version version; private final Version version;
NodeChannels(DiscoveryNode node, List<Channel> channels, ConnectionProfile connectionProfile, Version handshakeVersion) { NodeChannels(DiscoveryNode node, List<TcpChannel> channels, ConnectionProfile connectionProfile, Version handshakeVersion) {
this.node = node; this.node = node;
this.channels = Collections.unmodifiableList(channels); this.channels = Collections.unmodifiableList(channels);
assert channels.size() == connectionProfile.getNumConnections() : "expected channels size to be == " assert channels.size() == connectionProfile.getNumConnections() : "expected channels size to be == "
@ -420,11 +420,11 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
return version; return version;
} }
public List<Channel> getChannels() { public List<TcpChannel> getChannels() {
return channels; return channels;
} }
public Channel channel(TransportRequestOptions.Type type) { public TcpChannel channel(TransportRequestOptions.Type type) {
ConnectionProfile.ConnectionTypeHandle connectionTypeHandle = typeMapping.get(type); ConnectionProfile.ConnectionTypeHandle connectionTypeHandle = typeMapping.get(type);
if (connectionTypeHandle == null) { if (connectionTypeHandle == null) {
throw new IllegalArgumentException("no type channel for [" + type + "]"); throw new IllegalArgumentException("no type channel for [" + type + "]");
@ -477,7 +477,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
if (closed.get()) { if (closed.get()) {
throw new NodeNotConnectedException(node, "connection already closed"); throw new NodeNotConnectedException(node, "connection already closed");
} }
Channel channel = channel(options.type()); TcpChannel channel = channel(options.type());
sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), (byte) 0); sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), (byte) 0);
} }
@ -594,13 +594,13 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
try { try {
int numConnections = connectionProfile.getNumConnections(); int numConnections = connectionProfile.getNumConnections();
assert numConnections > 0 : "A connection profile must be configured with at least one connection"; assert numConnections > 0 : "A connection profile must be configured with at least one connection";
List<Channel> channels = new ArrayList<>(numConnections); List<TcpChannel> channels = new ArrayList<>(numConnections);
List<ActionFuture<Channel>> connectionFutures = new ArrayList<>(numConnections); List<ActionFuture<TcpChannel>> connectionFutures = new ArrayList<>(numConnections);
for (int i = 0; i < numConnections; ++i) { for (int i = 0; i < numConnections; ++i) {
try { try {
PlainActionFuture<Channel> connectFuture = PlainActionFuture.newFuture(); PlainActionFuture<TcpChannel> connectFuture = PlainActionFuture.newFuture();
connectionFutures.add(connectFuture); connectionFutures.add(connectFuture);
Channel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture); TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture);
channels.add(channel); channels.add(channel);
} catch (Exception e) { } catch (Exception e) {
// If there was an exception when attempting to instantiate the raw channels, we close all of the channels // If there was an exception when attempting to instantiate the raw channels, we close all of the channels
@ -618,7 +618,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} }
// If we make it past the block above, we have successfully established connections for all of the channels // If we make it past the block above, we have successfully established connections for all of the channels
final Channel handshakeChannel = channels.get(0); // one channel is guaranteed by the connection profile final TcpChannel handshakeChannel = channels.get(0); // one channel is guaranteed by the connection profile
handshakeChannel.addCloseListener(ActionListener.wrap(() -> cancelHandshakeForChannel(handshakeChannel))); handshakeChannel.addCloseListener(ActionListener.wrap(() -> cancelHandshakeForChannel(handshakeChannel)));
Version version; Version version;
try { try {
@ -635,7 +635,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
transportService.onConnectionOpened(nodeChannels); transportService.onConnectionOpened(nodeChannels);
final NodeChannels finalNodeChannels = nodeChannels; final NodeChannels finalNodeChannels = nodeChannels;
final AtomicBoolean runOnce = new AtomicBoolean(false); final AtomicBoolean runOnce = new AtomicBoolean(false);
Consumer<Channel> onClose = c -> { Consumer<TcpChannel> onClose = c -> {
assert c.isOpen() == false : "channel is still open when onClose is called"; assert c.isOpen() == false : "channel is still open when onClose is called";
// we only need to disconnect from the nodes once since all other channels // we only need to disconnect from the nodes once since all other channels
// will also try to run this we protect it from running multiple times. // will also try to run this we protect it from running multiple times.
@ -772,15 +772,15 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
final AtomicReference<InetSocketAddress> boundSocket = new AtomicReference<>(); final AtomicReference<InetSocketAddress> boundSocket = new AtomicReference<>();
boolean success = portsRange.iterate(portNumber -> { boolean success = portsRange.iterate(portNumber -> {
try { try {
Channel channel = bind(name, new InetSocketAddress(hostAddress, portNumber)); TcpChannel channel = bind(name, new InetSocketAddress(hostAddress, portNumber));
synchronized (serverChannels) { synchronized (serverChannels) {
List<Channel> list = serverChannels.get(name); List<TcpChannel> list = serverChannels.get(name);
if (list == null) { if (list == null) {
list = new ArrayList<>(); list = new ArrayList<>();
serverChannels.put(name, list); serverChannels.put(name, list);
} }
list.add(channel); list.add(channel);
boundSocket.set(getLocalAddress(channel)); boundSocket.set(channel.getLocalAddress());
} }
} catch (Exception e) { } catch (Exception e) {
lastException.set(e); lastException.set(e);
@ -937,9 +937,9 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
closeLock.writeLock().lock(); closeLock.writeLock().lock();
try { try {
// first stop to accept any incoming connections so nobody can connect to this transport // first stop to accept any incoming connections so nobody can connect to this transport
for (Map.Entry<String, List<Channel>> entry : serverChannels.entrySet()) { for (Map.Entry<String, List<TcpChannel>> entry : serverChannels.entrySet()) {
String profile = entry.getKey(); String profile = entry.getKey();
List<Channel> channels = entry.getValue(); List<TcpChannel> channels = entry.getValue();
ActionListener<TcpChannel> closeFailLogger = ActionListener.wrap(c -> {}, ActionListener<TcpChannel> closeFailLogger = ActionListener.wrap(c -> {},
e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e)); e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e));
channels.forEach(c -> c.addCloseListener(closeFailLogger)); channels.forEach(c -> c.addCloseListener(closeFailLogger));
@ -979,7 +979,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} }
} }
protected void onException(Channel channel, Exception e) { protected void onException(TcpChannel channel, Exception e) {
if (!lifecycle.started()) { if (!lifecycle.started()) {
// just close and ignore - we are already stopped and just need to make sure we release all resources // just close and ignore - we are already stopped and just need to make sure we release all resources
TcpChannel.closeChannel(channel, false); TcpChannel.closeChannel(channel, false);
@ -1014,9 +1014,9 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
// in case we are able to return data, serialize the exception content and sent it back to the client // in case we are able to return data, serialize the exception content and sent it back to the client
if (channel.isOpen()) { if (channel.isOpen()) {
BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8)); BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8));
final SendMetricListener<Channel> closeChannel = new SendMetricListener<Channel>(message.length()) { final SendMetricListener closeChannel = new SendMetricListener(message.length()) {
@Override @Override
protected void innerInnerOnResponse(Channel channel) { protected void innerInnerOnResponse(TcpChannel channel) {
TcpChannel.closeChannel(channel, false); TcpChannel.closeChannel(channel, false);
} }
@ -1036,34 +1036,19 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} }
} }
protected void serverAcceptedChannel(Channel channel) { protected void serverAcceptedChannel(TcpChannel channel) {
boolean addedOnThisCall = acceptedChannels.add(channel); boolean addedOnThisCall = acceptedChannels.add(channel);
assert addedOnThisCall : "Channel should only be added to accept channel set once"; assert addedOnThisCall : "Channel should only be added to accept channel set once";
channel.addCloseListener(ActionListener.wrap(() -> acceptedChannels.remove(channel))); channel.addCloseListener(ActionListener.wrap(() -> acceptedChannels.remove(channel)));
} }
/**
* Returns the channels local address
*/
protected abstract InetSocketAddress getLocalAddress(Channel channel);
/** /**
* Binds to the given {@link InetSocketAddress} * Binds to the given {@link InetSocketAddress}
* *
* @param name the profile name * @param name the profile name
* @param address the address to bind to * @param address the address to bind to
*/ */
protected abstract Channel bind(String name, InetSocketAddress address) throws IOException; protected abstract TcpChannel bind(String name, InetSocketAddress address) throws IOException;
/**
* Sends message to channel. The listener's onResponse method will be called when the send is complete unless an exception
* is thrown during the send. If an exception is thrown, the listener's onException method will be called.
*
* @param channel the destination channel
* @param reference the byte reference for the message
* @param listener the listener to call when the operation has completed
*/
protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener);
/** /**
* Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout. * Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout.
@ -1075,7 +1060,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
* @return the pending connection * @return the pending connection
* @throws IOException if an I/O exception occurs while opening the channel * @throws IOException if an I/O exception occurs while opening the channel
*/ */
protected abstract Channel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Channel> connectListener) protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> connectListener)
throws IOException; throws IOException;
/** /**
@ -1088,7 +1073,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
return compress && (!(request instanceof BytesTransportRequest)); return compress && (!(request instanceof BytesTransportRequest));
} }
private void sendRequestToChannel(final DiscoveryNode node, final Channel targetChannel, final long requestId, final String action, private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
final TransportRequest request, TransportRequestOptions options, Version channelVersion, final TransportRequest request, TransportRequestOptions options, Version channelVersion,
byte status) throws IOException, byte status) throws IOException,
TransportException { TransportException {
@ -1120,9 +1105,9 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream); BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream);
final TransportRequestOptions finalOptions = options; final TransportRequestOptions finalOptions = options;
// this might be called in a different thread // this might be called in a different thread
SendListener onRequestSent = new SendListener(stream, SendListener onRequestSent = new SendListener(channel, stream,
() -> transportService.onRequestSent(node, requestId, action, request, finalOptions), message.length()); () -> transportService.onRequestSent(node, requestId, action, request, finalOptions), message.length());
internalSendMessage(targetChannel, message, onRequestSent); internalSendMessage(channel, message, onRequestSent);
addedReleaseListener = true; addedReleaseListener = true;
} finally { } finally {
if (!addedReleaseListener) { if (!addedReleaseListener) {
@ -1134,13 +1119,13 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
/** /**
* sends a message to the given channel, using the given callbacks. * sends a message to the given channel, using the given callbacks.
*/ */
private void internalSendMessage(Channel targetChannel, BytesReference message, SendMetricListener<Channel> listener) { private void internalSendMessage(TcpChannel channel, BytesReference message, SendMetricListener listener) {
try { try {
sendMessage(targetChannel, message, listener); channel.sendMessage(message, listener);
} catch (Exception ex) { } catch (Exception ex) {
// call listener to ensure that any resources are released // call listener to ensure that any resources are released
listener.onFailure(ex); listener.onFailure(ex);
onException(targetChannel, ex); onException(channel, ex);
} }
} }
@ -1153,12 +1138,12 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
* @param requestId the request ID this response replies to * @param requestId the request ID this response replies to
* @param action the action this response replies to * @param action the action this response replies to
*/ */
public void sendErrorResponse(Version nodeVersion, Channel channel, final Exception error, final long requestId, public void sendErrorResponse(Version nodeVersion, TcpChannel channel, final Exception error, final long requestId,
final String action) throws IOException { final String action) throws IOException {
try (BytesStreamOutput stream = new BytesStreamOutput()) { try (BytesStreamOutput stream = new BytesStreamOutput()) {
stream.setVersion(nodeVersion); stream.setVersion(nodeVersion);
RemoteTransportException tx = new RemoteTransportException( RemoteTransportException tx = new RemoteTransportException(
nodeName(), new TransportAddress(getLocalAddress(channel)), action, error); nodeName(), new TransportAddress(channel.getLocalAddress()), action, error);
threadPool.getThreadContext().writeTo(stream); threadPool.getThreadContext().writeTo(stream);
stream.writeException(tx); stream.writeException(tx);
byte status = 0; byte status = 0;
@ -1167,7 +1152,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
final BytesReference bytes = stream.bytes(); final BytesReference bytes = stream.bytes();
final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length()); final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length());
CompositeBytesReference message = new CompositeBytesReference(header, bytes); CompositeBytesReference message = new CompositeBytesReference(header, bytes);
SendListener onResponseSent = new SendListener(null, SendListener onResponseSent = new SendListener(channel, null,
() -> transportService.onResponseSent(requestId, action, error), message.length()); () -> transportService.onResponseSent(requestId, action, error), message.length());
internalSendMessage(channel, message, onResponseSent); internalSendMessage(channel, message, onResponseSent);
} }
@ -1178,12 +1163,12 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
* *
* @see #sendErrorResponse(Version, TcpChannel, Exception, long, String) for sending back errors to the caller * @see #sendErrorResponse(Version, TcpChannel, Exception, long, String) for sending back errors to the caller
*/ */
public void sendResponse(Version nodeVersion, Channel channel, final TransportResponse response, final long requestId, public void sendResponse(Version nodeVersion, TcpChannel channel, final TransportResponse response, final long requestId,
final String action, TransportResponseOptions options) throws IOException { final String action, TransportResponseOptions options) throws IOException {
sendResponse(nodeVersion, channel, response, requestId, action, options, (byte) 0); sendResponse(nodeVersion, channel, response, requestId, action, options, (byte) 0);
} }
private void sendResponse(Version nodeVersion, Channel channel, final TransportResponse response, final long requestId, private void sendResponse(Version nodeVersion, TcpChannel channel, final TransportResponse response, final long requestId,
final String action, TransportResponseOptions options, byte status) throws IOException { final String action, TransportResponseOptions options, byte status) throws IOException {
if (compress) { if (compress) {
options = TransportResponseOptions.builder(options).withCompress(true).build(); options = TransportResponseOptions.builder(options).withCompress(true).build();
@ -1202,7 +1187,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
final TransportResponseOptions finalOptions = options; final TransportResponseOptions finalOptions = options;
// this might be called in a different thread // this might be called in a different thread
SendListener listener = new SendListener(stream, SendListener listener = new SendListener(channel, stream,
() -> transportService.onResponseSent(requestId, action, response, finalOptions), message.length()); () -> transportService.onResponseSent(requestId, action, response, finalOptions), message.length());
internalSendMessage(channel, message, listener); internalSendMessage(channel, message, listener);
addedReleaseListener = true; addedReleaseListener = true;
@ -1355,7 +1340,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
/** /**
* This method handles the message receive part for both request and responses * This method handles the message receive part for both request and responses
*/ */
public final void messageReceived(BytesReference reference, Channel channel, String profileName, public final void messageReceived(BytesReference reference, TcpChannel channel, String profileName,
InetSocketAddress remoteAddress, int messageLengthBytes) throws IOException { InetSocketAddress remoteAddress, int messageLengthBytes) throws IOException {
final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
readBytesMetric.inc(totalMessageSize); readBytesMetric.inc(totalMessageSize);
@ -1494,8 +1479,9 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
}); });
} }
protected String handleRequest(Channel channel, String profileName, final StreamInput stream, long requestId, int messageLengthBytes, protected String handleRequest(TcpChannel channel, String profileName, final StreamInput stream, long requestId,
Version version, InetSocketAddress remoteAddress, byte status) throws IOException { int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status)
throws IOException {
final String action = stream.readString(); final String action = stream.readString();
transportService.onRequestReceived(requestId, action); transportService.onRequestReceived(requestId, action);
TransportChannel transportChannel = null; TransportChannel transportChannel = null;
@ -1514,7 +1500,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} else { } else {
getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes); getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes);
} }
transportChannel = new TcpTransportChannel<>(this, channel, transportName, action, requestId, version, profileName, transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, profileName,
messageLengthBytes); messageLengthBytes);
final TransportRequest request = reg.newRequest(stream); final TransportRequest request = reg.newRequest(stream);
request.remoteAddress(new TransportAddress(remoteAddress)); request.remoteAddress(new TransportAddress(remoteAddress));
@ -1525,7 +1511,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} catch (Exception e) { } catch (Exception e) {
// the circuit breaker tripped // the circuit breaker tripped
if (transportChannel == null) { if (transportChannel == null) {
transportChannel = new TcpTransportChannel<>(this, channel, transportName, action, requestId, version, profileName, 0); transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, profileName, 0);
} }
try { try {
transportChannel.sendResponse(e); transportChannel.sendResponse(e);
@ -1611,7 +1597,8 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} }
} }
protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValue timeout) throws IOException, InterruptedException { protected Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout)
throws IOException, InterruptedException {
numHandshakes.inc(); numHandshakes.inc();
final long requestId = newRequestId(); final long requestId = newRequestId();
final HandshakeResponseHandler handler = new HandshakeResponseHandler(channel); final HandshakeResponseHandler handler = new HandshakeResponseHandler(channel);
@ -1671,7 +1658,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
/** /**
* Called once the channel is closed for instance due to a disconnect or a closed socket etc. * Called once the channel is closed for instance due to a disconnect or a closed socket etc.
*/ */
private void cancelHandshakeForChannel(Channel channel) { private void cancelHandshakeForChannel(TcpChannel channel) {
final Optional<Long> first = pendingHandshakes.entrySet().stream() final Optional<Long> first = pendingHandshakes.entrySet().stream()
.filter((entry) -> entry.getValue().channel == channel).map(Map.Entry::getKey).findFirst(); .filter((entry) -> entry.getValue().channel == channel).map(Map.Entry::getKey).findFirst();
if (first.isPresent()) { if (first.isPresent()) {
@ -1699,7 +1686,7 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
/** /**
* This listener increments the transmitted bytes metric on success. * This listener increments the transmitted bytes metric on success.
*/ */
private abstract class SendMetricListener<T> extends NotifyOnceListener<T> { private abstract class SendMetricListener extends NotifyOnceListener<TcpChannel> {
private final long messageSize; private final long messageSize;
private SendMetricListener(long messageSize) { private SendMetricListener(long messageSize) {
@ -1707,31 +1694,34 @@ public abstract class TcpTransport<Channel extends TcpChannel> extends AbstractL
} }
@Override @Override
protected final void innerOnResponse(T object) { protected final void innerOnResponse(org.elasticsearch.transport.TcpChannel object) {
transmittedBytesMetric.inc(messageSize); transmittedBytesMetric.inc(messageSize);
innerInnerOnResponse(object); innerInnerOnResponse(object);
} }
protected abstract void innerInnerOnResponse(T object); protected abstract void innerInnerOnResponse(org.elasticsearch.transport.TcpChannel object);
} }
private final class SendListener extends SendMetricListener<Channel> { private final class SendListener extends SendMetricListener {
private final TcpChannel channel;
private final Releasable optionalReleasable; private final Releasable optionalReleasable;
private final Runnable transportAdaptorCallback; private final Runnable transportAdaptorCallback;
private SendListener(Releasable optionalReleasable, Runnable transportAdaptorCallback, long messageLength) { private SendListener(TcpChannel channel, Releasable optionalReleasable, Runnable transportAdaptorCallback, long messageLength) {
super(messageLength); super(messageLength);
this.channel = channel;
this.optionalReleasable = optionalReleasable; this.optionalReleasable = optionalReleasable;
this.transportAdaptorCallback = transportAdaptorCallback; this.transportAdaptorCallback = transportAdaptorCallback;
} }
@Override @Override
protected void innerInnerOnResponse(Channel channel) { protected void innerInnerOnResponse(TcpChannel channel) {
release(); release();
} }
@Override @Override
protected void innerOnFailure(Exception e) { protected void innerOnFailure(Exception e) {
logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e);
release(); release();
} }

View File

@ -23,8 +23,8 @@ import org.elasticsearch.Version;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
public final class TcpTransportChannel<Channel extends TcpChannel> implements TransportChannel { public final class TcpTransportChannel implements TransportChannel {
private final TcpTransport<Channel> transport; private final TcpTransport transport;
private final Version version; private final Version version;
private final String action; private final String action;
private final long requestId; private final long requestId;
@ -32,9 +32,9 @@ public final class TcpTransportChannel<Channel extends TcpChannel> implements Tr
private final long reservedBytes; private final long reservedBytes;
private final AtomicBoolean released = new AtomicBoolean(); private final AtomicBoolean released = new AtomicBoolean();
private final String channelType; private final String channelType;
private final Channel channel; private final TcpChannel channel;
TcpTransportChannel(TcpTransport<Channel> transport, Channel channel, String channelType, String action, TcpTransportChannel(TcpTransport transport, TcpChannel channel, String channelType, String action,
long requestId, Version version, String profileName, long reservedBytes) { long requestId, Version version, String profileName, long reservedBytes) {
this.version = version; this.version = version;
this.channel = channel; this.channel = channel;
@ -97,7 +97,7 @@ public final class TcpTransportChannel<Channel extends TcpChannel> implements Tr
return version; return version;
} }
public Channel getChannel() { public TcpChannel getChannel() {
return channel; return channel;
} }
} }

View File

@ -39,7 +39,6 @@ import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -172,57 +171,23 @@ public class TcpTransportTests extends ESTestCase {
public void testCompressRequest() throws IOException { public void testCompressRequest() throws IOException {
final boolean compressed = randomBoolean(); final boolean compressed = randomBoolean();
final AtomicBoolean called = new AtomicBoolean(false);
Req request = new Req(randomRealisticUnicodeOfLengthBetween(10, 100)); Req request = new Req(randomRealisticUnicodeOfLengthBetween(10, 100));
ThreadPool threadPool = new TestThreadPool(TcpTransportTests.class.getName()); ThreadPool threadPool = new TestThreadPool(TcpTransportTests.class.getName());
AtomicReference<IOException> exceptionReference = new AtomicReference<>(); AtomicReference<BytesReference> messageCaptor = new AtomicReference<>();
try { try {
TcpTransport<FakeChannel> transport = new TcpTransport<FakeChannel>( TcpTransport transport = new TcpTransport(
"test", Settings.builder().put("transport.tcp.compress", compressed).build(), threadPool, "test", Settings.builder().put("transport.tcp.compress", compressed).build(), threadPool,
new BigArrays(Settings.EMPTY, null), null, null, null) { new BigArrays(Settings.EMPTY, null), null, null, null) {
@Override
protected InetSocketAddress getLocalAddress(FakeChannel o) {
return null;
}
@Override @Override
protected FakeChannel bind(String name, InetSocketAddress address) throws IOException { protected FakeChannel bind(String name, InetSocketAddress address) throws IOException {
return null; return null;
} }
@Override
protected void sendMessage(FakeChannel o, BytesReference reference, ActionListener listener) {
try {
StreamInput streamIn = reference.streamInput();
streamIn.skip(TcpHeader.MARKER_BYTES_SIZE);
int len = streamIn.readInt();
long requestId = streamIn.readLong();
assertEquals(42, requestId);
byte status = streamIn.readByte();
Version version = Version.fromId(streamIn.readInt());
assertEquals(Version.CURRENT, version);
assertEquals(compressed, TransportStatus.isCompress(status));
called.compareAndSet(false, true);
if (compressed) {
final int bytesConsumed = TcpHeader.HEADER_SIZE;
streamIn = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed))
.streamInput(streamIn);
}
threadPool.getThreadContext().readHeaders(streamIn);
assertEquals("foobar", streamIn.readString());
Req readReq = new Req("");
readReq.readFrom(streamIn);
assertEquals(request.value, readReq.value);
} catch (IOException e) {
exceptionReference.set(e);
}
}
@Override @Override
protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout,
ActionListener<FakeChannel> connectListener) throws IOException { ActionListener<TcpChannel> connectListener) throws IOException {
FakeChannel fakeChannel = new FakeChannel(); return new FakeChannel(messageCaptor);
return fakeChannel;
} }
@Override @Override
@ -233,18 +198,41 @@ public class TcpTransportTests extends ESTestCase {
@Override @Override
public NodeChannels getConnection(DiscoveryNode node) { public NodeChannels getConnection(DiscoveryNode node) {
int numConnections = MockTcpTransport.LIGHT_PROFILE.getNumConnections(); int numConnections = MockTcpTransport.LIGHT_PROFILE.getNumConnections();
ArrayList<FakeChannel> fakeChannels = new ArrayList<>(numConnections); ArrayList<TcpChannel> fakeChannels = new ArrayList<>(numConnections);
for (int i = 0; i < numConnections; ++i) { for (int i = 0; i < numConnections; ++i) {
fakeChannels.add(new FakeChannel()); fakeChannels.add(new FakeChannel(messageCaptor));
} }
return new NodeChannels(node, fakeChannels, MockTcpTransport.LIGHT_PROFILE, Version.CURRENT); return new NodeChannels(node, fakeChannels, MockTcpTransport.LIGHT_PROFILE, Version.CURRENT);
} }
}; };
DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT); DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT);
Transport.Connection connection = transport.getConnection(node); Transport.Connection connection = transport.getConnection(node);
connection.sendRequest(42, "foobar", request, TransportRequestOptions.EMPTY); connection.sendRequest(42, "foobar", request, TransportRequestOptions.EMPTY);
assertTrue(called.get());
assertNull("IOException while sending message.", exceptionReference.get()); BytesReference reference = messageCaptor.get();
assertNotNull(reference);
StreamInput streamIn = reference.streamInput();
streamIn.skip(TcpHeader.MARKER_BYTES_SIZE);
int len = streamIn.readInt();
long requestId = streamIn.readLong();
assertEquals(42, requestId);
byte status = streamIn.readByte();
Version version = Version.fromId(streamIn.readInt());
assertEquals(Version.CURRENT, version);
assertEquals(compressed, TransportStatus.isCompress(status));
if (compressed) {
final int bytesConsumed = TcpHeader.HEADER_SIZE;
streamIn = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed))
.streamInput(streamIn);
}
threadPool.getThreadContext().readHeaders(streamIn);
assertEquals("foobar", streamIn.readString());
Req readReq = new Req("");
readReq.readFrom(streamIn);
assertEquals(request.value, readReq.value);
} finally { } finally {
ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
} }
@ -252,6 +240,12 @@ public class TcpTransportTests extends ESTestCase {
private static final class FakeChannel implements TcpChannel { private static final class FakeChannel implements TcpChannel {
private final AtomicReference<BytesReference> messageCaptor;
FakeChannel(AtomicReference<BytesReference> messageCaptor) {
this.messageCaptor = messageCaptor;
}
@Override @Override
public void close() { public void close() {
} }
@ -268,6 +262,16 @@ public class TcpTransportTests extends ESTestCase {
public boolean isOpen() { public boolean isOpen() {
return false; return false;
} }
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) {
messageCaptor.set(reference);
}
} }
private static final class Req extends TransportRequest { private static final class Req extends TransportRequest {

View File

@ -42,7 +42,6 @@ import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lease.Releasables;
@ -57,6 +56,7 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportRequestOptions;
@ -79,7 +79,7 @@ import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadF
* longer. Med is for the typical search / single doc index. And High for things like cluster state. Ping is reserved for * longer. Med is for the typical search / single doc index. And High for things like cluster state. Ping is reserved for
* sending out ping requests to other nodes. * sending out ping requests to other nodes.
*/ */
public class Netty4Transport extends TcpTransport<NettyTcpChannel> { public class Netty4Transport extends TcpTransport {
static { static {
Netty4Utils.setup(); Netty4Utils.setup();
@ -249,7 +249,7 @@ public class Netty4Transport extends TcpTransport<NettyTcpChannel> {
} }
@Override @Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<NettyTcpChannel> listener) protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> listener)
throws IOException { throws IOException {
ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address()); ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address());
Channel channel = channelFuture.channel(); Channel channel = channelFuture.channel();
@ -279,28 +279,6 @@ public class Netty4Transport extends TcpTransport<NettyTcpChannel> {
return nettyChannel; return nettyChannel;
} }
@Override
protected void sendMessage(NettyTcpChannel channel, BytesReference reference, ActionListener<NettyTcpChannel> listener) {
final ChannelFuture future = channel.getLowLevelChannel().writeAndFlush(Netty4Utils.toByteBuf(reference));
future.addListener(f -> {
if (f.isSuccess()) {
listener.onResponse(channel);
} else {
final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause);
logger.warn((Supplier<?>) () ->
new ParameterizedMessage("write and flush on the network layer failed (channel: {})", channel), cause);
assert cause instanceof Exception;
listener.onFailure((Exception) cause);
}
});
}
@Override
protected InetSocketAddress getLocalAddress(NettyTcpChannel channel) {
return (InetSocketAddress) channel.getLowLevelChannel().localAddress();
}
@Override @Override
protected NettyTcpChannel bind(String name, InetSocketAddress address) { protected NettyTcpChannel bind(String name, InetSocketAddress address) {
Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel(); Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel();

View File

@ -20,10 +20,15 @@
package org.elasticsearch.transport.netty4; package org.elasticsearch.transport.netty4;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpChannel;
import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public class NettyTcpChannel implements TcpChannel { public class NettyTcpChannel implements TcpChannel {
@ -48,10 +53,6 @@ public class NettyTcpChannel implements TcpChannel {
}); });
} }
public Channel getLowLevelChannel() {
return channel;
}
@Override @Override
public void close() { public void close() {
channel.close(); channel.close();
@ -71,4 +72,28 @@ public class NettyTcpChannel implements TcpChannel {
public boolean isOpen() { public boolean isOpen() {
return channel.isOpen(); return channel.isOpen();
} }
@Override
public InetSocketAddress getLocalAddress() {
return (InetSocketAddress) channel.localAddress();
}
@Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) {
final ChannelFuture future = channel.writeAndFlush(Netty4Utils.toByteBuf(reference));
future.addListener(f -> {
if (f.isSuccess()) {
listener.onResponse(this);
} else {
final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause);
assert cause instanceof Exception;
listener.onFailure((Exception) cause);
}
});
}
public Channel getLowLevelChannel() {
return channel;
}
} }

View File

@ -36,6 +36,7 @@ import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
import org.elasticsearch.test.ESIntegTestCase.Scope; import org.elasticsearch.test.ESIntegTestCase.Scope;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
@ -108,7 +109,8 @@ public class Netty4TransportIT extends ESNetty4IntegTestCase {
super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService);
} }
protected String handleRequest(NettyTcpChannel channel, String profileName, @Override
protected String handleRequest(TcpChannel channel, String profileName,
StreamInput stream, long requestId, int messageLengthBytes, Version version, StreamInput stream, long requestId, int messageLengthBytes, Version version,
InetSocketAddress remoteAddress, byte status) throws IOException { InetSocketAddress remoteAddress, byte status) throws IOException {
String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version, String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version,

View File

@ -30,7 +30,6 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportService;
import org.junit.Before; import org.junit.Before;
import java.util.Collections; import java.util.Collections;
@ -59,7 +58,7 @@ public class NettyTransportMultiPortTests extends ESTestCase {
.build(); .build();
ThreadPool threadPool = new TestThreadPool("tst"); ThreadPool threadPool = new TestThreadPool("tst");
try (TcpTransport<?> transport = startTransport(settings, threadPool)) { try (TcpTransport transport = startTransport(settings, threadPool)) {
assertEquals(1, transport.profileBoundAddresses().size()); assertEquals(1, transport.profileBoundAddresses().size());
assertEquals(1, transport.boundAddress().boundAddresses().length); assertEquals(1, transport.boundAddress().boundAddresses().length);
} finally { } finally {
@ -75,7 +74,7 @@ public class NettyTransportMultiPortTests extends ESTestCase {
.build(); .build();
ThreadPool threadPool = new TestThreadPool("tst"); ThreadPool threadPool = new TestThreadPool("tst");
try (TcpTransport<?> transport = startTransport(settings, threadPool)) { try (TcpTransport transport = startTransport(settings, threadPool)) {
assertEquals(1, transport.profileBoundAddresses().size()); assertEquals(1, transport.profileBoundAddresses().size());
assertEquals(1, transport.boundAddress().boundAddresses().length); assertEquals(1, transport.boundAddress().boundAddresses().length);
} finally { } finally {
@ -108,7 +107,7 @@ public class NettyTransportMultiPortTests extends ESTestCase {
.build(); .build();
ThreadPool threadPool = new TestThreadPool("tst"); ThreadPool threadPool = new TestThreadPool("tst");
try (TcpTransport<?> transport = startTransport(settings, threadPool)) { try (TcpTransport transport = startTransport(settings, threadPool)) {
assertEquals(0, transport.profileBoundAddresses().size()); assertEquals(0, transport.profileBoundAddresses().size());
assertEquals(1, transport.boundAddress().boundAddresses().length); assertEquals(1, transport.boundAddress().boundAddresses().length);
} finally { } finally {
@ -116,9 +115,9 @@ public class NettyTransportMultiPortTests extends ESTestCase {
} }
} }
private TcpTransport<?> startTransport(Settings settings, ThreadPool threadPool) { private TcpTransport startTransport(Settings settings, ThreadPool threadPool) {
BigArrays bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService()); BigArrays bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService());
TcpTransport<?> transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), TcpTransport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()),
bigArrays, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService()); bigArrays, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService());
transport.start(); transport.start();

View File

@ -58,7 +58,7 @@ public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase
BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) {
@Override @Override
protected Version executeHandshake(DiscoveryNode node, NettyTcpChannel channel, TimeValue timeout) throws IOException, protected Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException,
InterruptedException { InterruptedException {
if (doHandshake) { if (doHandshake) {
return super.executeHandshake(node, channel, timeout); return super.executeHandshake(node, channel, timeout);
@ -90,7 +90,7 @@ public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException { protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final Netty4Transport t = (Netty4Transport) transport; final Netty4Transport t = (Netty4Transport) transport;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final TcpTransport<NettyTcpChannel>.NodeChannels channels = (TcpTransport<NettyTcpChannel>.NodeChannels) connection; final TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection;
TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true); TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true);
} }

View File

@ -1976,7 +1976,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList())) { new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList())) {
@Override @Override
protected String handleRequest(MockChannel mockChannel, String profileName, StreamInput stream, long requestId, protected String handleRequest(TcpChannel mockChannel, String profileName, StreamInput stream, long requestId,
int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status) int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status)
throws IOException { throws IOException {
return super.handleRequest(mockChannel, profileName, stream, requestId, messageLengthBytes, version, remoteAddress, return super.handleRequest(mockChannel, profileName, stream, requestId, messageLengthBytes, version, remoteAddress,

View File

@ -68,7 +68,7 @@ import java.util.function.Consumer;
* that need real networking. This implementation is a test only implementation that implements * that need real networking. This implementation is a test only implementation that implements
* the networking layer in the worst possible way since it blocks and uses a thread per request model. * the networking layer in the worst possible way since it blocks and uses a thread per request model.
*/ */
public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel> { public class MockTcpTransport extends TcpTransport {
/** /**
* A pre-built light connection profile that shares a single connection across all * A pre-built light connection profile that shares a single connection across all
@ -109,11 +109,6 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
this.mockVersion = mockVersion; this.mockVersion = mockVersion;
} }
@Override
protected InetSocketAddress getLocalAddress(MockChannel mockChannel) {
return mockChannel.localAddress;
}
@Override @Override
protected MockChannel bind(final String name, InetSocketAddress address) throws IOException { protected MockChannel bind(final String name, InetSocketAddress address) throws IOException {
MockServerSocket socket = new MockServerSocket(); MockServerSocket socket = new MockServerSocket();
@ -176,7 +171,7 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
} }
@Override @Override
protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<MockChannel> connectListener) protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> connectListener)
throws IOException { throws IOException {
InetSocketAddress address = node.getAddress().address(); InetSocketAddress address = node.getAddress().address();
final MockSocket socket = new MockSocket(); final MockSocket socket = new MockSocket();
@ -222,22 +217,6 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
socket.setReuseAddress(TCP_REUSE_ADDRESS.get(settings)); socket.setReuseAddress(TCP_REUSE_ADDRESS.get(settings));
} }
@Override
protected void sendMessage(MockChannel mockChannel, BytesReference reference, ActionListener<MockChannel> listener) {
try {
synchronized (mockChannel) {
final Socket socket = mockChannel.activeChannel;
OutputStream outputStream = new BufferedOutputStream(socket.getOutputStream());
reference.writeTo(outputStream);
outputStream.flush();
}
listener.onResponse(mockChannel);
} catch (IOException e) {
listener.onFailure(e);
onException(mockChannel, e);
}
}
@Override @Override
public long getNumOpenServerConnections() { public long getNumOpenServerConnections() {
return 1; return 1;
@ -401,6 +380,25 @@ public class MockTcpTransport extends TcpTransport<MockTcpTransport.MockChannel>
return isOpen.get(); return isOpen.get();
} }
@Override
public InetSocketAddress getLocalAddress() {
return localAddress;
}
@Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) {
try {
synchronized (this) {
OutputStream outputStream = new BufferedOutputStream(activeChannel.getOutputStream());
reference.writeTo(outputStream);
outputStream.flush();
}
listener.onResponse(this);
} catch (IOException e) {
listener.onFailure(e);
onException(this, e);
}
}
} }

View File

@ -23,7 +23,6 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
@ -33,6 +32,7 @@ import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transports; import org.elasticsearch.transport.Transports;
import org.elasticsearch.transport.nio.channel.ChannelFactory; import org.elasticsearch.transport.nio.channel.ChannelFactory;
@ -54,7 +54,7 @@ import static org.elasticsearch.common.settings.Setting.intSetting;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory; import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
public class NioTransport extends TcpTransport<NioChannel> { public class NioTransport extends TcpTransport {
public static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_WORKER_THREAD_NAME_PREFIX; public static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_WORKER_THREAD_NAME_PREFIX;
public static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX; public static final String TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX = Transports.NIO_TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX;
@ -87,11 +87,6 @@ public class NioTransport extends TcpTransport<NioChannel> {
return openChannels.serverChannelsCount(); return openChannels.serverChannelsCount();
} }
@Override
protected InetSocketAddress getLocalAddress(NioChannel channel) {
return channel.getLocalAddress();
}
@Override @Override
protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
ChannelFactory channelFactory = this.profileToChannelFactory.get(name); ChannelFactory channelFactory = this.profileToChannelFactory.get(name);
@ -100,21 +95,22 @@ public class NioTransport extends TcpTransport<NioChannel> {
} }
@Override @Override
protected void sendMessage(NioChannel channel, BytesReference reference, ActionListener<NioChannel> listener) { protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<TcpChannel> connectListener)
if (channel instanceof NioSocketChannel) {
NioSocketChannel nioSocketChannel = (NioSocketChannel) channel;
nioSocketChannel.getWriteContext().sendMessage(reference, listener);
} else {
logger.error("cannot send message to channel of this type [{}]", channel.getClass());
}
}
@Override
protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<NioChannel> connectListener)
throws IOException { throws IOException {
NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get()); NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
openChannels.clientChannelOpened(channel); openChannels.clientChannelOpened(channel);
channel.addConnectListener(connectListener); // TODO: Temporary conversion due to types
channel.addConnectListener(new ActionListener<NioChannel>() {
@Override
public void onResponse(NioChannel nioChannel) {
connectListener.onResponse(nioChannel);
}
@Override
public void onFailure(Exception e) {
connectListener.onFailure(e);
}
});
return channel; return channel;
} }

View File

@ -19,6 +19,9 @@
package org.elasticsearch.transport.nio.channel; package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.AcceptingSelector;
import java.io.IOException; import java.io.IOException;
@ -39,6 +42,11 @@ public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChann
return channelFactory; return channelFactory;
} }
@Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) {
throw new UnsupportedOperationException("Cannot send a message to a server channel.");
}
@Override @Override
public String toString() { public String toString() {
return "NioServerSocketChannel{" + return "NioServerSocketChannel{" +

View File

@ -20,6 +20,8 @@
package org.elasticsearch.transport.nio.channel; package org.elasticsearch.transport.nio.channel;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.nio.NetworkBytesReference; import org.elasticsearch.transport.nio.NetworkBytesReference;
import org.elasticsearch.transport.nio.SocketSelector; import org.elasticsearch.transport.nio.SocketSelector;
@ -46,6 +48,22 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
this.socketSelector = selector; this.socketSelector = selector;
} }
@Override
public void sendMessage(BytesReference reference, ActionListener<TcpChannel> listener) {
// TODO: Temporary conversion due to types
writeContext.sendMessage(reference, new ActionListener<NioChannel>() {
@Override
public void onResponse(NioChannel nioChannel) {
listener.onResponse(nioChannel);
}
@Override
public void onFailure(Exception e) {
listener.onFailure(e);
}
});
}
@Override @Override
public void closeFromSelector() throws IOException { public void closeFromSelector() throws IOException {
assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; assert socketSelector.isOnCurrentThread() : "Should only call from selector thread";

View File

@ -40,7 +40,7 @@ public class MockTcpTransportTests extends AbstractSimpleTransportTestCase {
Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version) { new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version) {
@Override @Override
protected Version executeHandshake(DiscoveryNode node, MockChannel mockChannel, TimeValue timeout) throws IOException, protected Version executeHandshake(DiscoveryNode node, TcpChannel mockChannel, TimeValue timeout) throws IOException,
InterruptedException { InterruptedException {
if (doHandshake) { if (doHandshake) {
return super.executeHandshake(node, mockChannel, timeout); return super.executeHandshake(node, mockChannel, timeout);
@ -58,8 +58,8 @@ public class MockTcpTransportTests extends AbstractSimpleTransportTestCase {
@Override @Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException { protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final MockTcpTransport t = (MockTcpTransport) transport; final MockTcpTransport t = (MockTcpTransport) transport;
@SuppressWarnings("unchecked") final TcpTransport<MockTcpTransport.MockChannel>.NodeChannels channels = @SuppressWarnings("unchecked") final TcpTransport.NodeChannels channels =
(TcpTransport<MockTcpTransport.MockChannel>.NodeChannels) connection; (TcpTransport.NodeChannels) connection;
TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true); TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true);
} }

View File

@ -39,7 +39,6 @@ import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.nio.channel.NioChannel;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
@ -62,7 +61,7 @@ public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase {
BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) {
@Override @Override
protected Version executeHandshake(DiscoveryNode node, NioChannel channel, TimeValue timeout) throws IOException, protected Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException,
InterruptedException { InterruptedException {
if (doHandshake) { if (doHandshake) {
return super.executeHandshake(node, channel, timeout); return super.executeHandshake(node, channel, timeout);
@ -100,7 +99,7 @@ public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase {
@Override @Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException { protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
TcpTransport<NioChannel>.NodeChannels channels = (TcpTransport<NioChannel>.NodeChannels) connection; TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection;
TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true); TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true);
} }