367502 - WebSocket connections should be closed when application context is stopped.

This commit is contained in:
Simone Bordet 2011-12-23 23:43:02 +01:00
parent 62f47e0617
commit 3cda41248e
4 changed files with 126 additions and 43 deletions

View File

@ -348,8 +348,8 @@ public class WebSocketClient
return holder; return holder;
} }
public static final InetSocketAddress toSocketAddress(URI uri) public static InetSocketAddress toSocketAddress(URI uri)
{ {
String scheme = uri.getScheme(); String scheme = uri.getScheme();
if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme))) if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme)))
@ -360,8 +360,7 @@ public class WebSocketClient
if (port < 0) if (port < 0)
port = "ws".equals(scheme) ? 80 : 443; port = "ws".equals(scheme) ? 80 : 443;
InetSocketAddress address = new InetSocketAddress(uri.getHost(), port); return new InetSocketAddress(uri.getHost(), port);
return address;
} }
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
@ -371,16 +370,8 @@ public class WebSocketClient
{ {
final WebSocket _websocket; final WebSocket _websocket;
final URI _uri; final URI _uri;
final String _protocol; final WebSocketClient _client;
final String _origin;
final MaskGen _maskGen;
final int _maxIdleTime;
final int _maxTextMessageSize;
final int _maxBinaryMessageSize;
final Map<String,String> _cookies;
final List<String> _extensions;
final CountDownLatch _done = new CountDownLatch(1); final CountDownLatch _done = new CountDownLatch(1);
ByteChannel _channel; ByteChannel _channel;
WebSocketConnection _connection; WebSocketConnection _connection;
Throwable _exception; Throwable _exception;
@ -389,14 +380,7 @@ public class WebSocketClient
{ {
_websocket=websocket; _websocket=websocket;
_uri=uri; _uri=uri;
_protocol=client._protocol; _client=client;
_origin=client._origin;
_maskGen=client._maskGen;
_maxIdleTime=client._maxIdleTime;
_maxTextMessageSize=client._maxTextMessageSize;
_maxBinaryMessageSize=client._maxBinaryMessageSize;
_cookies=client._cookies;
_extensions=client._extensions;
_channel=channel; _channel=channel;
} }
@ -404,8 +388,10 @@ public class WebSocketClient
{ {
try try
{ {
connection.getConnection().setMaxTextMessageSize(_maxTextMessageSize); _client.getFactory().addConnection(connection);
connection.getConnection().setMaxBinaryMessageSize(_maxBinaryMessageSize);
connection.getConnection().setMaxTextMessageSize(_client.getMaxTextMessageSize());
connection.getConnection().setMaxBinaryMessageSize(_client.getMaxBinaryMessageSize());
WebSocketConnection con; WebSocketConnection con;
synchronized (this) synchronized (this)
@ -460,12 +446,12 @@ public class WebSocketClient
public Map<String,String> getCookies() public Map<String,String> getCookies()
{ {
return _cookies; return _client.getCookies();
} }
public String getProtocol() public String getProtocol()
{ {
return _protocol; return _client.getProtocol();
} }
public WebSocket getWebSocket() public WebSocket getWebSocket()
@ -480,17 +466,17 @@ public class WebSocketClient
public int getMaxIdleTime() public int getMaxIdleTime()
{ {
return _maxIdleTime; return _client.getMaxIdleTime();
} }
public String getOrigin() public String getOrigin()
{ {
return _origin; return _client.getOrigin();
} }
public MaskGen getMaskGen() public MaskGen getMaskGen()
{ {
return _maskGen; return _client.getMaskGen();
} }
@Override @Override

View File

@ -20,8 +20,11 @@ import java.io.IOException;
import java.net.ProtocolException; import java.net.ProtocolException;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Queue;
import java.util.Random; import java.util.Random;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.http.HttpFields;
@ -33,6 +36,7 @@ import org.eclipse.jetty.io.Buffers;
import org.eclipse.jetty.io.ByteArrayBuffer; import org.eclipse.jetty.io.ByteArrayBuffer;
import org.eclipse.jetty.io.ConnectedEndPoint; import org.eclipse.jetty.io.ConnectedEndPoint;
import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.io.SimpleBuffers; import org.eclipse.jetty.io.SimpleBuffers;
import org.eclipse.jetty.io.nio.AsyncConnection; import org.eclipse.jetty.io.nio.AsyncConnection;
import org.eclipse.jetty.io.nio.SelectChannelEndPoint; import org.eclipse.jetty.io.nio.SelectChannelEndPoint;
@ -60,8 +64,8 @@ public class WebSocketClientFactory extends AggregateLifeCycle
{ {
private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName()); private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName());
private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept"); private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept");
private final Queue<WebSocketConnection> connections = new ConcurrentLinkedQueue<WebSocketConnection>();
private SslContextFactory _sslContextFactory = new SslContextFactory(); private final SslContextFactory _sslContextFactory = new SslContextFactory();
private final ThreadPool _threadPool; private final ThreadPool _threadPool;
private final WebSocketClientSelector _selector; private final WebSocketClientSelector _selector;
private MaskGen _maskGen; private MaskGen _maskGen;
@ -200,6 +204,12 @@ public class WebSocketClientFactory extends AggregateLifeCycle
return _buffers.getBufferSize(); return _buffers.getBufferSize();
} }
@Override
protected void doStop() throws Exception
{
closeConnections();
}
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
/** /**
* <p>Creates and returns a new instance of a {@link WebSocketClient}, configured with this * <p>Creates and returns a new instance of a {@link WebSocketClient}, configured with this
@ -231,6 +241,22 @@ public class WebSocketClientFactory extends AggregateLifeCycle
return sslEngine; return sslEngine;
} }
protected boolean addConnection(WebSocketConnection connection)
{
return isRunning() && connections.add(connection);
}
protected boolean removeConnection(WebSocketConnection connection)
{
return connections.remove(connection);
}
protected void closeConnections()
{
for (WebSocketConnection connection : connections)
connection.shutdown();
}
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
/** /**
* WebSocket Client Selector Manager * WebSocket Client Selector Manager
@ -457,18 +483,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
} }
else else
{ {
Buffer header = _parser.getHeaderBuffer(); WebSocketConnection connection = newWebSocketConnection();
MaskGen maskGen = _future.getMaskGen();
WebSocketConnectionRFC6455 connection =
new WebSocketConnectionRFC6455(_future.getWebSocket(),
_endp,
_buffers, System.currentTimeMillis(),
_future.getMaxIdleTime(),
_future.getProtocol(),
null,
WebSocketConnectionRFC6455.VERSION,
maskGen);
Buffer header = _parser.getHeaderBuffer();
if (header.hasContent()) if (header.hasContent())
connection.fillBuffersFrom(header); connection.fillBuffersFrom(header);
_buffers.returnBuffer(header); _buffers.returnBuffer(header);
@ -483,6 +500,21 @@ public class WebSocketClientFactory extends AggregateLifeCycle
return this; return this;
} }
private WebSocketConnection newWebSocketConnection() throws IOException
{
return new WebSocketClientConnection(
_future._client.getFactory(),
_future.getWebSocket(),
_endp,
_buffers,
System.currentTimeMillis(),
_future.getMaxIdleTime(),
_future.getProtocol(),
null,
WebSocketConnectionRFC6455.VERSION,
_future.getMaskGen());
}
public void onInputShutdown() throws IOException public void onInputShutdown() throws IOException
{ {
_endp.close(); _endp.close();
@ -506,4 +538,22 @@ public class WebSocketClientFactory extends AggregateLifeCycle
_future.handshakeFailed(new EOFException()); _future.handshakeFailed(new EOFException());
} }
} }
private static class WebSocketClientConnection extends WebSocketConnectionRFC6455
{
private final WebSocketClientFactory factory;
public WebSocketClientConnection(WebSocketClientFactory factory, WebSocket webSocket, EndPoint endPoint, WebSocketBuffers buffers, long timeStamp, int maxIdleTime, String protocol, List<Extension> extensions, int draftVersion, MaskGen maskGen) throws IOException
{
super(webSocket, endPoint, buffers, timeStamp, maxIdleTime, protocol, extensions, draftVersion, maskGen);
this.factory = factory;
}
@Override
public void onClose()
{
super.onClose();
factory.removeConnection(this);
}
}
} }

View File

@ -273,12 +273,12 @@ public class WebSocketFactory extends AbstractLifeCycle
} }
} }
addConnection(connection);
// Set the defaults // Set the defaults
connection.getConnection().setMaxBinaryMessageSize(_maxBinaryMessageSize); connection.getConnection().setMaxBinaryMessageSize(_maxBinaryMessageSize);
connection.getConnection().setMaxTextMessageSize(_maxTextMessageSize); connection.getConnection().setMaxTextMessageSize(_maxTextMessageSize);
addConnection(connection);
// Let the connection finish processing the handshake // Let the connection finish processing the handshake
connection.handshake(request, response, protocol); connection.handshake(request, response, protocol);
response.flushBuffer(); response.flushBuffer();

View File

@ -126,4 +126,51 @@ public class WebSocketRedeployTest
Assert.assertTrue(closeLatch.await(5, TimeUnit.SECONDS)); Assert.assertTrue(closeLatch.await(5, TimeUnit.SECONDS));
} }
@Test
public void testStoppingClientFactoryClosesConnections() throws Exception
{
final CountDownLatch openLatch = new CountDownLatch(2);
final CountDownLatch closeLatch = new CountDownLatch(2);
init(new WebSocket.OnTextMessage()
{
public void onOpen(Connection connection)
{
openLatch.countDown();
}
public void onMessage(String data)
{
}
public void onClose(int closeCode, String message)
{
closeLatch.countDown();
}
});
WebSocketClient client = wsFactory.newWebSocketClient();
client.open(new URI(uri), new WebSocket.OnTextMessage()
{
public void onOpen(Connection connection)
{
openLatch.countDown();
}
public void onMessage(String data)
{
}
public void onClose(int closeCode, String message)
{
closeLatch.countDown();
}
}, 5, TimeUnit.SECONDS);
Assert.assertTrue(openLatch.await(5, TimeUnit.SECONDS));
wsFactory.stop();
Assert.assertTrue(closeLatch.await(5, TimeUnit.SECONDS));
}
} }