Clean up the transport implementation and reduce duplication.
This commit is contained in:
Timothy Bish 2016-07-18 16:54:37 -04:00
parent 34d7b0bfcb
commit 4b018b4206
8 changed files with 195 additions and 459 deletions

View File

@ -1,4 +1,4 @@
/** /*
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with * contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. * this work for additional information regarding copyright ownership.
@ -33,8 +33,10 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
@ -58,15 +60,13 @@ public class NettyTcpTransport implements NettyTransport {
protected EventLoopGroup group; protected EventLoopGroup group;
protected Channel channel; protected Channel channel;
protected NettyTransportListener listener; protected NettyTransportListener listener;
protected NettyTransportOptions options; protected final NettyTransportOptions options;
protected final URI remote; protected final URI remote;
protected boolean secure;
private final AtomicBoolean connected = new AtomicBoolean(); private final AtomicBoolean connected = new AtomicBoolean();
private final AtomicBoolean closed = new AtomicBoolean(); private final AtomicBoolean closed = new AtomicBoolean();
private final CountDownLatch connectLatch = new CountDownLatch(1); private final CountDownLatch connectLatch = new CountDownLatch(1);
private IOException failureCause; private IOException failureCause;
private Throwable pendingFailure;
/** /**
* Create a new transport instance * Create a new transport instance
@ -91,10 +91,17 @@ public class NettyTcpTransport implements NettyTransport {
* the transport options used to configure the socket connection. * the transport options used to configure the socket connection.
*/ */
public NettyTcpTransport(NettyTransportListener listener, URI remoteLocation, NettyTransportOptions options) { public NettyTcpTransport(NettyTransportListener listener, URI remoteLocation, NettyTransportOptions options) {
if (options == null) {
throw new IllegalArgumentException("Transport Options cannot be null");
}
if (remoteLocation == null) {
throw new IllegalArgumentException("Transport remote location cannot be null");
}
this.options = options; this.options = options;
this.listener = listener; this.listener = listener;
this.remote = remoteLocation; this.remote = remoteLocation;
this.secure = remoteLocation.getScheme().equalsIgnoreCase("ssl");
} }
@Override @Override
@ -104,16 +111,27 @@ public class NettyTcpTransport implements NettyTransport {
throw new IllegalStateException("A transport listener must be set before connection attempts."); throw new IllegalStateException("A transport listener must be set before connection attempts.");
} }
final SslHandler sslHandler;
if (isSSL()) {
try {
sslHandler = NettyTransportSupport.createSslHandler(getRemoteLocation(), getSslOptions());
} catch (Exception ex) {
// TODO: can we stop it throwing Exception?
throw IOExceptionSupport.create(ex);
}
} else {
sslHandler = null;
}
group = new NioEventLoopGroup(1); group = new NioEventLoopGroup(1);
bootstrap = new Bootstrap(); bootstrap = new Bootstrap();
bootstrap.group(group); bootstrap.group(group);
bootstrap.channel(NioSocketChannel.class); bootstrap.channel(NioSocketChannel.class);
bootstrap.handler(new ChannelInitializer<Channel>() { bootstrap.handler(new ChannelInitializer<Channel>() {
@Override @Override
public void initChannel(Channel connectedChannel) throws Exception { public void initChannel(Channel connectedChannel) throws Exception {
configureChannel(connectedChannel); configureChannel(connectedChannel, sslHandler);
} }
}); });
@ -124,12 +142,8 @@ public class NettyTcpTransport implements NettyTransport {
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) { if (!future.isSuccess()) {
handleConnected(future.channel()); handleException(future.channel(), IOExceptionSupport.create(future.cause()));
} else if (future.isCancelled()) {
connectionFailed(future.channel(), new IOException("Connection attempt was cancelled"));
} else {
connectionFailed(future.channel(), IOExceptionSupport.create(future.cause()));
} }
} }
}); });
@ -160,8 +174,8 @@ public class NettyTcpTransport implements NettyTransport {
@Override @Override
public void run() { public void run() {
if (pendingFailure != null) { if (failureCause != null) {
channel.pipeline().fireExceptionCaught(pendingFailure); channel.pipeline().fireExceptionCaught(failureCause);
} }
} }
}); });
@ -175,7 +189,7 @@ public class NettyTcpTransport implements NettyTransport {
@Override @Override
public boolean isSSL() { public boolean isSSL() {
return secure; return options.isSSL();
} }
@Override @Override
@ -222,14 +236,6 @@ public class NettyTcpTransport implements NettyTransport {
@Override @Override
public NettyTransportOptions getTransportOptions() { public NettyTransportOptions getTransportOptions() {
if (options == null) {
if (isSSL()) {
options = NettyTransportSslOptions.INSTANCE;
} else {
options = NettyTransportOptions.INSTANCE;
}
}
return options; return options;
} }
@ -240,36 +246,106 @@ public class NettyTcpTransport implements NettyTransport {
@Override @Override
public Principal getLocalPrincipal() { public Principal getLocalPrincipal() {
if (!isSSL()) { Principal result = null;
throw new UnsupportedOperationException("Not connected to a secure channel");
}
if (isSSL()) {
SslHandler sslHandler = channel.pipeline().get(SslHandler.class); SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
result = sslHandler.engine().getSession().getLocalPrincipal();
return sslHandler.engine().getSession().getLocalPrincipal();
} }
//----- Internal implementation details, can be overridden as needed --// return result;
}
//----- Internal implementation details, can be overridden as needed -----//
protected String getRemoteHost() { protected String getRemoteHost() {
return remote.getHost(); return remote.getHost();
} }
protected int getRemotePort() { protected int getRemotePort() {
int port = remote.getPort(); if (remote.getPort() != -1) {
return remote.getPort();
if (port <= 0) {
if (isSSL()) {
port = getSslOptions().getDefaultSslPort();
} else { } else {
port = getTransportOptions().getDefaultTcpPort(); return isSSL() ? getSslOptions().getDefaultSslPort() : getTransportOptions().getDefaultTcpPort();
} }
} }
return port; protected void addAdditionalHandlers(ChannelPipeline pipeline) {
} }
protected void configureNetty(Bootstrap bootstrap, NettyTransportOptions options) { protected ChannelInboundHandlerAdapter createChannelHandler() {
return new NettyTcpTransportHandler();
}
//----- Event Handlers which can be overridden in subclasses -------------//
protected void handleConnected(Channel channel) throws Exception {
LOG.trace("Channel has become active! Channel is {}", channel);
connectionEstablished(channel);
}
protected void handleChannelInactive(Channel channel) throws Exception {
LOG.trace("Channel has gone inactive! Channel is {}", channel);
if (connected.compareAndSet(true, false) && !closed.get()) {
LOG.trace("Firing onTransportClosed listener");
listener.onTransportClosed();
}
}
protected void handleException(Channel channel, Throwable cause) throws Exception {
LOG.trace("Exception on channel! Channel is {}", channel);
if (connected.compareAndSet(true, false) && !closed.get()) {
LOG.trace("Firing onTransportError listener");
if (failureCause != null) {
listener.onTransportError(failureCause);
} else {
listener.onTransportError(cause);
}
} else {
// Hold the first failure for later dispatch if connect succeeds.
// This will then trigger disconnect using the first error reported.
if (failureCause == null) {
LOG.trace("Holding error until connect succeeds: {}", cause.getMessage());
failureCause = IOExceptionSupport.create(cause);
}
connectionFailed(channel, failureCause);
}
}
//----- State change handlers and checks ---------------------------------//
protected final void checkConnected() throws IOException {
if (!connected.get()) {
throw new IOException("Cannot send to a non-connected transport.");
}
}
/*
* Called when the transport has successfully connected and is ready for use.
*/
private void connectionEstablished(Channel connectedChannel) {
channel = connectedChannel;
connected.set(true);
connectLatch.countDown();
}
/*
* Called when the transport connection failed and an error should be returned.
*/
private void connectionFailed(Channel failedChannel, IOException cause) {
failureCause = cause;
channel = failedChannel;
connected.set(false);
connectLatch.countDown();
}
private NettyTransportSslOptions getSslOptions() {
return (NettyTransportSslOptions) getTransportOptions();
}
private void configureNetty(Bootstrap bootstrap, NettyTransportOptions options) {
bootstrap.option(ChannelOption.TCP_NODELAY, options.isTcpNoDelay()); bootstrap.option(ChannelOption.TCP_NODELAY, options.isTcpNoDelay());
bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, options.getConnectTimeout()); bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, options.getConnectTimeout());
bootstrap.option(ChannelOption.SO_KEEPALIVE, options.isTcpKeepAlive()); bootstrap.option(ChannelOption.SO_KEEPALIVE, options.isTcpKeepAlive());
@ -290,108 +366,63 @@ public class NettyTcpTransport implements NettyTransport {
} }
} }
protected void configureChannel(final Channel channel) throws Exception { private void configureChannel(final Channel channel, final SslHandler sslHandler) throws Exception {
if (isSSL()) { if (isSSL()) {
SslHandler sslHandler = NettyTransportSupport.createSslHandler(getRemoteLocation(), getSslOptions()); channel.pipeline().addLast(sslHandler);
}
addAdditionalHandlers(channel.pipeline());
channel.pipeline().addLast(createChannelHandler());
}
//----- Handle connection events -----------------------------------------//
protected abstract class NettyDefaultHandler<E> extends SimpleChannelInboundHandler<E> {
@Override
public void channelRegistered(ChannelHandlerContext context) throws Exception {
channel = context.channel();
}
@Override
public void channelActive(ChannelHandlerContext context) throws Exception {
// In the Secure case we need to let the handshake complete before we
// trigger the connected event.
if (!isSSL()) {
handleConnected(context.channel());
} else {
SslHandler sslHandler = context.pipeline().get(SslHandler.class);
sslHandler.handshakeFuture().addListener(new GenericFutureListener<Future<Channel>>() { sslHandler.handshakeFuture().addListener(new GenericFutureListener<Future<Channel>>() {
@Override @Override
public void operationComplete(Future<Channel> future) throws Exception { public void operationComplete(Future<Channel> future) throws Exception {
if (future.isSuccess()) { if (future.isSuccess()) {
LOG.trace("SSL Handshake has completed: {}", channel); LOG.trace("SSL Handshake has completed: {}", channel);
connectionEstablished(channel); handleConnected(channel);
} else { } else {
LOG.trace("SSL Handshake has failed: {}", channel); LOG.trace("SSL Handshake has failed: {}", channel);
connectionFailed(channel, IOExceptionSupport.create(future.cause())); handleException(channel, future.cause());
} }
} }
}); });
channel.pipeline().addLast(sslHandler);
} }
channel.pipeline().addLast(new NettyTcpTransportHandler());
}
protected void handleConnected(final Channel channel) throws Exception {
if (!isSSL()) {
connectionEstablished(channel);
}
}
//----- State change handlers and checks ---------------------------------//
/**
* Called when the transport has successfully connected and is ready for use.
*/
protected void connectionEstablished(Channel connectedChannel) {
channel = connectedChannel;
connected.set(true);
connectLatch.countDown();
}
/**
* Called when the transport connection failed and an error should be returned.
*
* @param failedChannel
* The Channel instance that failed.
* @param cause
* An IOException that describes the cause of the failed connection.
*/
protected void connectionFailed(Channel failedChannel, IOException cause) {
failureCause = IOExceptionSupport.create(cause);
channel = failedChannel;
connected.set(false);
connectLatch.countDown();
}
private NettyTransportSslOptions getSslOptions() {
return (NettyTransportSslOptions) getTransportOptions();
}
private void checkConnected() throws IOException {
if (!connected.get()) {
throw new IOException("Cannot send to a non-connected transport.");
}
}
//----- Handle connection events -----------------------------------------//
private class NettyTcpTransportHandler extends SimpleChannelInboundHandler<ByteBuf> {
@Override
public void channelActive(ChannelHandlerContext context) throws Exception {
LOG.trace("Channel has become active! Channel is {}", context.channel());
} }
@Override @Override
public void channelInactive(ChannelHandlerContext context) throws Exception { public void channelInactive(ChannelHandlerContext context) throws Exception {
LOG.trace("Channel has gone inactive! Channel is {}", context.channel()); handleChannelInactive(context.channel());
if (connected.compareAndSet(true, false) && !closed.get()) {
LOG.trace("Firing onTransportClosed listener");
listener.onTransportClosed();
}
} }
@Override @Override
public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception { public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
LOG.trace("Exception on channel! Channel is {}", context.channel()); handleException(context.channel(), cause);
if (connected.compareAndSet(true, false) && !closed.get()) {
LOG.trace("Firing onTransportError listener");
if (pendingFailure != null) {
listener.onTransportError(pendingFailure);
} else {
listener.onTransportError(cause);
}
} else {
// Hold the first failure for later dispatch if connect succeeds.
// This will then trigger disconnect using the first error reported.
if (pendingFailure != null) {
LOG.trace("Holding error until connect succeeds: {}", cause.getMessage());
pendingFailure = cause;
}
} }
} }
//----- Handle Binary data from connection -------------------------------//
protected class NettyTcpTransportHandler extends NettyDefaultHandler<ByteBuf> {
@Override @Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception { protected void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
LOG.trace("New data read: {} bytes incoming: {}", buffer.readableBytes(), buffer); LOG.trace("New data read: {} bytes incoming: {}", buffer.readableBytes(), buffer);

View File

@ -23,7 +23,7 @@ import java.security.Principal;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
/** /**
* * Base for all Netty based Transports in this client.
*/ */
public interface NettyTransport { public interface NettyTransport {

View File

@ -1,4 +1,4 @@
/** /*
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with * contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. * this work for additional information regarding copyright ownership.

View File

@ -1,4 +1,4 @@
/** /*
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with * contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. * this work for additional information regarding copyright ownership.

View File

@ -1,4 +1,4 @@
/** /*
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with * contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. * this work for additional information regarding copyright ownership.
@ -163,6 +163,10 @@ public class NettyTransportOptions implements Cloneable {
this.defaultTcpPort = defaultTcpPort; this.defaultTcpPort = defaultTcpPort;
} }
public boolean isSSL() {
return false;
}
@Override @Override
public NettyTransportOptions clone() { public NettyTransportOptions clone() {
return copyOptions(new NettyTransportOptions()); return copyOptions(new NettyTransportOptions());

View File

@ -1,4 +1,4 @@
/** /*
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with * contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. * this work for additional information regarding copyright ownership.
@ -261,6 +261,11 @@ public class NettyTransportSslOptions extends NettyTransportOptions {
this.defaultSslPort = defaultSslPort; this.defaultSslPort = defaultSslPort;
} }
@Override
public boolean isSSL() {
return true;
}
@Override @Override
public NettyTransportSslOptions clone() { public NettyTransportSslOptions clone() {
return copyOptions(new NettyTransportSslOptions()); return copyOptions(new NettyTransportSslOptions());

View File

@ -1,4 +1,4 @@
/** /*
* Licensed to the Apache Software Foundation (ASF) under one or more * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with * contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. * this work for additional information regarding copyright ownership.
@ -16,8 +16,6 @@
*/ */
package org.apache.activemq.transport.amqp.client.transport; package org.apache.activemq.transport.amqp.client.transport;
import io.netty.handler.ssl.SslHandler;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.InputStream; import java.io.InputStream;
@ -44,6 +42,8 @@ import javax.net.ssl.X509TrustManager;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import io.netty.handler.ssl.SslHandler;
/** /**
* Static class that provides various utility methods used by Transport implementations. * Static class that provides various utility methods used by Transport implementations.
*/ */

View File

@ -18,68 +18,38 @@ package org.apache.activemq.transport.amqp.client.transport;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.security.Principal; import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.activemq.transport.amqp.client.util.IOExceptionSupport;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory; import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
/** /**
* Transport for communicating over WebSockets * Transport for communicating over WebSockets
*/ */
public class NettyWSTransport implements NettyTransport { public class NettyWSTransport extends NettyTcpTransport {
private static final Logger LOG = LoggerFactory.getLogger(NettyWSTransport.class); private static final Logger LOG = LoggerFactory.getLogger(NettyWSTransport.class);
private static final int QUIET_PERIOD = 20; private static final String AMQP_SUB_PROTOCOL = "amqp";
private static final int SHUTDOWN_TIMEOUT = 100;
protected Bootstrap bootstrap;
protected EventLoopGroup group;
protected Channel channel;
protected NettyTransportListener listener;
protected NettyTransportOptions options;
protected final URI remote;
protected boolean secure;
private final AtomicBoolean connected = new AtomicBoolean();
private final AtomicBoolean closed = new AtomicBoolean();
private ChannelPromise handshakeFuture;
private IOException failureCause;
private Throwable pendingFailure;
/** /**
* Create a new transport instance * Create a new transport instance
@ -104,114 +74,7 @@ public class NettyWSTransport implements NettyTransport {
* the transport options used to configure the socket connection. * the transport options used to configure the socket connection.
*/ */
public NettyWSTransport(NettyTransportListener listener, URI remoteLocation, NettyTransportOptions options) { public NettyWSTransport(NettyTransportListener listener, URI remoteLocation, NettyTransportOptions options) {
this.options = options; super(listener, remoteLocation, options);
this.listener = listener;
this.remote = remoteLocation;
this.secure = remoteLocation.getScheme().equalsIgnoreCase("wss");
}
@Override
public void connect() throws IOException {
if (listener == null) {
throw new IllegalStateException("A transport listener must be set before connection attempts.");
}
group = new NioEventLoopGroup(1);
bootstrap = new Bootstrap();
bootstrap.group(group);
bootstrap.channel(NioSocketChannel.class);
bootstrap.handler(new ChannelInitializer<Channel>() {
@Override
public void initChannel(Channel connectedChannel) throws Exception {
configureChannel(connectedChannel);
}
});
configureNetty(bootstrap, getTransportOptions());
ChannelFuture future;
try {
future = bootstrap.connect(getRemoteHost(), getRemotePort());
future.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
handleConnected(future.channel());
} else if (future.isCancelled()) {
connectionFailed(future.channel(), new IOException("Connection attempt was cancelled"));
} else {
connectionFailed(future.channel(), IOExceptionSupport.create(future.cause()));
}
}
});
future.sync();
// Now wait for WS protocol level handshake completion
handshakeFuture.await();
} catch (InterruptedException ex) {
LOG.debug("Transport connection attempt was interrupted.");
Thread.interrupted();
failureCause = IOExceptionSupport.create(ex);
}
if (failureCause != null) {
// Close out any Netty resources now as they are no longer needed.
if (channel != null) {
channel.close().syncUninterruptibly();
channel = null;
}
if (group != null) {
group.shutdownGracefully(QUIET_PERIOD, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS);
group = null;
}
throw failureCause;
} else {
// Connected, allow any held async error to fire now and close the transport.
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
if (pendingFailure != null) {
channel.pipeline().fireExceptionCaught(pendingFailure);
}
}
});
}
}
@Override
public boolean isConnected() {
return connected.get();
}
@Override
public boolean isSSL() {
return secure;
}
@Override
public void close() throws IOException {
if (closed.compareAndSet(false, true)) {
connected.set(false);
if (channel != null) {
channel.close().syncUninterruptibly();
}
if (group != null) {
group.shutdownGracefully(QUIET_PERIOD, SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS);
}
}
}
@Override
public ByteBuf allocateSendBuffer(int size) throws IOException {
checkConnected();
return channel.alloc().ioBuffer(size, size);
} }
@Override @Override
@ -228,206 +91,37 @@ public class NettyWSTransport implements NettyTransport {
} }
@Override @Override
public NettyTransportListener getTransportListener() { protected ChannelInboundHandlerAdapter createChannelHandler() {
return listener; return new NettyWebSocketTransportHandler();
} }
@Override @Override
public void setTransportListener(NettyTransportListener listener) { protected void addAdditionalHandlers(ChannelPipeline pipeline) {
this.listener = listener; pipeline.addLast(new HttpClientCodec());
pipeline.addLast(new HttpObjectAggregator(8192));
} }
@Override @Override
public NettyTransportOptions getTransportOptions() { protected void handleConnected(Channel channel) throws Exception {
if (options == null) { LOG.trace("Channel has become active, awaiting WebSocket handshake! Channel is {}", channel);
if (isSSL()) {
options = NettyTransportSslOptions.INSTANCE;
} else {
options = NettyTransportOptions.INSTANCE;
}
}
return options;
}
@Override
public URI getRemoteLocation() {
return remote;
}
@Override
public Principal getLocalPrincipal() {
if (!isSSL()) {
throw new UnsupportedOperationException("Not connected to a secure channel");
}
SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
return sslHandler.engine().getSession().getLocalPrincipal();
}
//----- Internal implementation details, can be overridden as needed --//
protected String getRemoteHost() {
return remote.getHost();
}
protected int getRemotePort() {
int port = remote.getPort();
if (port <= 0) {
if (isSSL()) {
port = getSslOptions().getDefaultSslPort();
} else {
port = getTransportOptions().getDefaultTcpPort();
}
}
return port;
}
protected void configureNetty(Bootstrap bootstrap, NettyTransportOptions options) {
bootstrap.option(ChannelOption.TCP_NODELAY, options.isTcpNoDelay());
bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, options.getConnectTimeout());
bootstrap.option(ChannelOption.SO_KEEPALIVE, options.isTcpKeepAlive());
bootstrap.option(ChannelOption.SO_LINGER, options.getSoLinger());
bootstrap.option(ChannelOption.ALLOCATOR, PartialPooledByteBufAllocator.INSTANCE);
if (options.getSendBufferSize() != -1) {
bootstrap.option(ChannelOption.SO_SNDBUF, options.getSendBufferSize());
}
if (options.getReceiveBufferSize() != -1) {
bootstrap.option(ChannelOption.SO_RCVBUF, options.getReceiveBufferSize());
bootstrap.option(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(options.getReceiveBufferSize()));
}
if (options.getTrafficClass() != -1) {
bootstrap.option(ChannelOption.IP_TOS, options.getTrafficClass());
}
}
protected void configureChannel(final Channel channel) throws Exception {
if (isSSL()) {
SslHandler sslHandler = NettyTransportSupport.createSslHandler(getRemoteLocation(), getSslOptions());
sslHandler.handshakeFuture().addListener(new GenericFutureListener<Future<Channel>>() {
@Override
public void operationComplete(Future<Channel> future) throws Exception {
if (future.isSuccess()) {
LOG.trace("SSL Handshake has completed: {}", channel);
connectionEstablished(channel);
} else {
LOG.trace("SSL Handshake has failed: {}", channel);
connectionFailed(channel, IOExceptionSupport.create(future.cause()));
}
}
});
channel.pipeline().addLast(sslHandler);
}
channel.pipeline().addLast(new HttpClientCodec());
channel.pipeline().addLast(new HttpObjectAggregator(8192));
channel.pipeline().addLast(new NettyTcpTransportHandler());
}
protected void handleConnected(final Channel channel) throws Exception {
if (!isSSL()) {
connectionEstablished(channel);
}
}
//----- State change handlers and checks ---------------------------------//
/**
* Called when the transport has successfully connected and is ready for use.
*/
protected void connectionEstablished(Channel connectedChannel) {
LOG.info("WebSocket connectionEstablished! {}", connectedChannel);
channel = connectedChannel;
connected.set(true);
}
/**
* Called when the transport connection failed and an error should be returned.
*
* @param failedChannel
* The Channel instance that failed.
* @param cause
* An IOException that describes the cause of the failed connection.
*/
protected void connectionFailed(Channel failedChannel, IOException cause) {
failureCause = IOExceptionSupport.create(cause);
channel = failedChannel;
connected.set(false);
handshakeFuture.setFailure(cause);
}
private NettyTransportSslOptions getSslOptions() {
return (NettyTransportSslOptions) getTransportOptions();
}
private void checkConnected() throws IOException {
if (!connected.get()) {
throw new IOException("Cannot send to a non-connected transport.");
}
} }
//----- Handle connection events -----------------------------------------// //----- Handle connection events -----------------------------------------//
private class NettyTcpTransportHandler extends SimpleChannelInboundHandler<Object> { private class NettyWebSocketTransportHandler extends NettyDefaultHandler<Object> {
private final WebSocketClientHandshaker handshaker; private final WebSocketClientHandshaker handshaker;
public NettyTcpTransportHandler() { public NettyWebSocketTransportHandler() {
handshaker = WebSocketClientHandshakerFactory.newHandshaker( handshaker = WebSocketClientHandshakerFactory.newHandshaker(
remote, WebSocketVersion.V13, "amqp", false, new DefaultHttpHeaders()); getRemoteLocation(), WebSocketVersion.V13, AMQP_SUB_PROTOCOL, true, new DefaultHttpHeaders());
}
@Override
public void handlerAdded(ChannelHandlerContext context) {
LOG.trace("Handler has become added! Channel is {}", context.channel());
handshakeFuture = context.newPromise();
} }
@Override @Override
public void channelActive(ChannelHandlerContext context) throws Exception { public void channelActive(ChannelHandlerContext context) throws Exception {
LOG.trace("Channel has become active! Channel is {}", context.channel());
handshaker.handshake(context.channel()); handshaker.handshake(context.channel());
}
@Override super.channelActive(context);
public void channelInactive(ChannelHandlerContext context) throws Exception {
LOG.trace("Channel has gone inactive! Channel is {}", context.channel());
if (connected.compareAndSet(true, false) && !closed.get()) {
LOG.trace("Firing onTransportClosed listener");
listener.onTransportClosed();
}
}
@Override
public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
LOG.trace("Exception on channel! Channel is {} -> {}", context.channel(), cause.getMessage());
LOG.trace("Error Stack: ", cause);
if (connected.compareAndSet(true, false) && !closed.get()) {
LOG.trace("Firing onTransportError listener");
if (pendingFailure != null) {
listener.onTransportError(pendingFailure);
} else {
listener.onTransportError(cause);
}
} else {
// Hold the first failure for later dispatch if connect succeeds.
// This will then trigger disconnect using the first error reported.
if (pendingFailure != null) {
LOG.trace("Holding error until connect succeeds: {}", cause.getMessage());
pendingFailure = cause;
}
if (!handshakeFuture.isDone()) {
handshakeFuture.setFailure(cause);
}
}
} }
@Override @Override
@ -437,8 +131,9 @@ public class NettyWSTransport implements NettyTransport {
Channel ch = ctx.channel(); Channel ch = ctx.channel();
if (!handshaker.isHandshakeComplete()) { if (!handshaker.isHandshakeComplete()) {
handshaker.finishHandshake(ch, (FullHttpResponse) message); handshaker.finishHandshake(ch, (FullHttpResponse) message);
LOG.info("WebSocket Client connected! {}", ctx.channel()); LOG.trace("WebSocket Client connected! {}", ctx.channel());
handshakeFuture.setSuccess(); // Now trigger super processing as we are really connected.
NettyWSTransport.super.handleConnected(ch);
return; return;
} }
@ -447,7 +142,7 @@ public class NettyWSTransport implements NettyTransport {
FullHttpResponse response = (FullHttpResponse) message; FullHttpResponse response = (FullHttpResponse) message;
throw new IllegalStateException( throw new IllegalStateException(
"Unexpected FullHttpResponse (getStatus=" + response.getStatus() + "Unexpected FullHttpResponse (getStatus=" + response.getStatus() +
", content=" + response.content().toString(CharsetUtil.UTF_8) + ')'); ", content=" + response.content().toString(StandardCharsets.UTF_8) + ')');
} }
WebSocketFrame frame = (WebSocketFrame) message; WebSocketFrame frame = (WebSocketFrame) message;
@ -457,10 +152,11 @@ public class NettyWSTransport implements NettyTransport {
ctx.fireExceptionCaught(new IOException("Received invalid frame over WebSocket.")); ctx.fireExceptionCaught(new IOException("Received invalid frame over WebSocket."));
} else if (frame instanceof BinaryWebSocketFrame) { } else if (frame instanceof BinaryWebSocketFrame) {
BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame;
LOG.info("WebSocket Client received data: {} bytes", binaryFrame.content().readableBytes()); LOG.trace("WebSocket Client received data: {} bytes", binaryFrame.content().readableBytes());
listener.onData(binaryFrame.content()); listener.onData(binaryFrame.content());
} else if (frame instanceof PongWebSocketFrame) { } else if (frame instanceof PingWebSocketFrame) {
LOG.trace("WebSocket Client received pong"); LOG.trace("WebSocket Client received ping, response with pong");
ch.write(new PongWebSocketFrame(frame.content()));
} else if (frame instanceof CloseWebSocketFrame) { } else if (frame instanceof CloseWebSocketFrame) {
LOG.trace("WebSocket Client received closing"); LOG.trace("WebSocket Client received closing");
ch.close(); ch.close();