diff --git a/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClient.java b/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClient.java index cc5db438838..743bdc38b64 100644 --- a/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClient.java +++ b/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClient.java @@ -348,8 +348,8 @@ public class WebSocketClient return holder; } - - public static final InetSocketAddress toSocketAddress(URI uri) + + public static InetSocketAddress toSocketAddress(URI uri) { String scheme = uri.getScheme(); if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme))) @@ -360,8 +360,7 @@ public class WebSocketClient if (port < 0) port = "ws".equals(scheme) ? 80 : 443; - InetSocketAddress address = new InetSocketAddress(uri.getHost(), port); - return address; + return new InetSocketAddress(uri.getHost(), port); } /* ------------------------------------------------------------ */ @@ -371,16 +370,8 @@ public class WebSocketClient { final WebSocket _websocket; final URI _uri; - final String _protocol; - final String _origin; - final MaskGen _maskGen; - final int _maxIdleTime; - final int _maxTextMessageSize; - final int _maxBinaryMessageSize; - final Map _cookies; - final List _extensions; + final WebSocketClient _client; final CountDownLatch _done = new CountDownLatch(1); - ByteChannel _channel; WebSocketConnection _connection; Throwable _exception; @@ -389,14 +380,7 @@ public class WebSocketClient { _websocket=websocket; _uri=uri; - _protocol=client._protocol; - _origin=client._origin; - _maskGen=client._maskGen; - _maxIdleTime=client._maxIdleTime; - _maxTextMessageSize=client._maxTextMessageSize; - _maxBinaryMessageSize=client._maxBinaryMessageSize; - _cookies=client._cookies; - _extensions=client._extensions; + _client=client; _channel=channel; } @@ -404,8 +388,10 @@ public class WebSocketClient { try { - connection.getConnection().setMaxTextMessageSize(_maxTextMessageSize); - connection.getConnection().setMaxBinaryMessageSize(_maxBinaryMessageSize); + _client.getFactory().addConnection(connection); + + connection.getConnection().setMaxTextMessageSize(_client.getMaxTextMessageSize()); + connection.getConnection().setMaxBinaryMessageSize(_client.getMaxBinaryMessageSize()); WebSocketConnection con; synchronized (this) @@ -460,12 +446,12 @@ public class WebSocketClient public Map getCookies() { - return _cookies; + return _client.getCookies(); } public String getProtocol() { - return _protocol; + return _client.getProtocol(); } public WebSocket getWebSocket() @@ -480,17 +466,17 @@ public class WebSocketClient public int getMaxIdleTime() { - return _maxIdleTime; + return _client.getMaxIdleTime(); } public String getOrigin() { - return _origin; + return _client.getOrigin(); } public MaskGen getMaskGen() { - return _maskGen; + return _client.getMaskGen(); } @Override diff --git a/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClientFactory.java b/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClientFactory.java index 35a1f39479c..4b250ab1b2d 100644 --- a/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClientFactory.java +++ b/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketClientFactory.java @@ -20,8 +20,11 @@ import java.io.IOException; import java.net.ProtocolException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; +import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; import javax.net.ssl.SSLEngine; import org.eclipse.jetty.http.HttpFields; @@ -33,6 +36,7 @@ import org.eclipse.jetty.io.Buffers; import org.eclipse.jetty.io.ByteArrayBuffer; import org.eclipse.jetty.io.ConnectedEndPoint; import org.eclipse.jetty.io.Connection; +import org.eclipse.jetty.io.EndPoint; import org.eclipse.jetty.io.SimpleBuffers; import org.eclipse.jetty.io.nio.AsyncConnection; import org.eclipse.jetty.io.nio.SelectChannelEndPoint; @@ -60,8 +64,8 @@ public class WebSocketClientFactory extends AggregateLifeCycle { private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName()); private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept"); - - private SslContextFactory _sslContextFactory = new SslContextFactory(); + private final Queue connections = new ConcurrentLinkedQueue(); + private final SslContextFactory _sslContextFactory = new SslContextFactory(); private final ThreadPool _threadPool; private final WebSocketClientSelector _selector; private MaskGen _maskGen; @@ -200,6 +204,12 @@ public class WebSocketClientFactory extends AggregateLifeCycle return _buffers.getBufferSize(); } + @Override + protected void doStop() throws Exception + { + closeConnections(); + } + /* ------------------------------------------------------------ */ /** *

Creates and returns a new instance of a {@link WebSocketClient}, configured with this @@ -231,6 +241,22 @@ public class WebSocketClientFactory extends AggregateLifeCycle return sslEngine; } + protected boolean addConnection(WebSocketConnection connection) + { + return isRunning() && connections.add(connection); + } + + protected boolean removeConnection(WebSocketConnection connection) + { + return connections.remove(connection); + } + + protected void closeConnections() + { + for (WebSocketConnection connection : connections) + connection.shutdown(); + } + /* ------------------------------------------------------------ */ /** * WebSocket Client Selector Manager @@ -457,18 +483,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle } else { - Buffer header = _parser.getHeaderBuffer(); - MaskGen maskGen = _future.getMaskGen(); - WebSocketConnectionRFC6455 connection = - new WebSocketConnectionRFC6455(_future.getWebSocket(), - _endp, - _buffers, System.currentTimeMillis(), - _future.getMaxIdleTime(), - _future.getProtocol(), - null, - WebSocketConnectionRFC6455.VERSION, - maskGen); + WebSocketConnection connection = newWebSocketConnection(); + Buffer header = _parser.getHeaderBuffer(); if (header.hasContent()) connection.fillBuffersFrom(header); _buffers.returnBuffer(header); @@ -483,6 +500,21 @@ public class WebSocketClientFactory extends AggregateLifeCycle return this; } + private WebSocketConnection newWebSocketConnection() throws IOException + { + return new WebSocketClientConnection( + _future._client.getFactory(), + _future.getWebSocket(), + _endp, + _buffers, + System.currentTimeMillis(), + _future.getMaxIdleTime(), + _future.getProtocol(), + null, + WebSocketConnectionRFC6455.VERSION, + _future.getMaskGen()); + } + public void onInputShutdown() throws IOException { _endp.close(); @@ -506,4 +538,22 @@ public class WebSocketClientFactory extends AggregateLifeCycle _future.handshakeFailed(new EOFException()); } } + + private static class WebSocketClientConnection extends WebSocketConnectionRFC6455 + { + private final WebSocketClientFactory factory; + + public WebSocketClientConnection(WebSocketClientFactory factory, WebSocket webSocket, EndPoint endPoint, WebSocketBuffers buffers, long timeStamp, int maxIdleTime, String protocol, List extensions, int draftVersion, MaskGen maskGen) throws IOException + { + super(webSocket, endPoint, buffers, timeStamp, maxIdleTime, protocol, extensions, draftVersion, maskGen); + this.factory = factory; + } + + @Override + public void onClose() + { + super.onClose(); + factory.removeConnection(this); + } + } } diff --git a/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketFactory.java b/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketFactory.java index fcd74a43fcf..9b76e07bfba 100644 --- a/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketFactory.java +++ b/jetty-websocket/src/main/java/org/eclipse/jetty/websocket/WebSocketFactory.java @@ -273,12 +273,12 @@ public class WebSocketFactory extends AbstractLifeCycle } } + addConnection(connection); + // Set the defaults connection.getConnection().setMaxBinaryMessageSize(_maxBinaryMessageSize); connection.getConnection().setMaxTextMessageSize(_maxTextMessageSize); - addConnection(connection); - // Let the connection finish processing the handshake connection.handshake(request, response, protocol); response.flushBuffer(); diff --git a/jetty-websocket/src/test/java/org/eclipse/jetty/websocket/WebSocketRedeployTest.java b/jetty-websocket/src/test/java/org/eclipse/jetty/websocket/WebSocketRedeployTest.java index b3a856de44d..7d5d91b29e5 100644 --- a/jetty-websocket/src/test/java/org/eclipse/jetty/websocket/WebSocketRedeployTest.java +++ b/jetty-websocket/src/test/java/org/eclipse/jetty/websocket/WebSocketRedeployTest.java @@ -126,4 +126,51 @@ public class WebSocketRedeployTest Assert.assertTrue(closeLatch.await(5, TimeUnit.SECONDS)); } + + @Test + public void testStoppingClientFactoryClosesConnections() throws Exception + { + final CountDownLatch openLatch = new CountDownLatch(2); + final CountDownLatch closeLatch = new CountDownLatch(2); + init(new WebSocket.OnTextMessage() + { + public void onOpen(Connection connection) + { + openLatch.countDown(); + } + + public void onMessage(String data) + { + } + + public void onClose(int closeCode, String message) + { + closeLatch.countDown(); + } + }); + + WebSocketClient client = wsFactory.newWebSocketClient(); + client.open(new URI(uri), new WebSocket.OnTextMessage() + { + public void onOpen(Connection connection) + { + openLatch.countDown(); + } + + public void onMessage(String data) + { + } + + public void onClose(int closeCode, String message) + { + closeLatch.countDown(); + } + }, 5, TimeUnit.SECONDS); + + Assert.assertTrue(openLatch.await(5, TimeUnit.SECONDS)); + + wsFactory.stop(); + + Assert.assertTrue(closeLatch.await(5, TimeUnit.SECONDS)); + } }