365750 - Support WebSocket over SSL, aka wss://

This is now implemented, using the new architecture of wrapping the connection with
SslConnection.
The only refactoring was to avoid that the HTTP handshake was sent from the
HandshakeConnection constructor, because at that point the SSL wiring is not ready yet.
Now the handshake is sent from handle(), guarded by a boolean variable to sent it once.
This commit is contained in:
Simone Bordet 2011-12-06 16:25:15 +01:00
parent ba95a9ba3a
commit 0689e05e9b
4 changed files with 296 additions and 132 deletions

View File

@ -320,8 +320,6 @@ public class WebSocketClient
String scheme=uri.getScheme();
if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme)))
throw new IllegalArgumentException("Bad WebSocket scheme '"+scheme+"'");
if ("wss".equalsIgnoreCase(scheme))
throw new IOException("wss not supported");
SocketChannel channel = SocketChannel.open();
if (_bindAddress != null)
@ -334,7 +332,7 @@ public class WebSocketClient
channel.configureBlocking(false);
channel.connect(address);
_factory.getSelectorManager().register( channel, holder);
_factory.getSelectorManager().register(channel, holder);
return holder;
}

View File

@ -5,7 +5,9 @@ import java.io.IOException;
import java.net.ProtocolException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.Map;
import java.util.Random;
import javax.net.ssl.SSLEngine;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpParser;
@ -20,30 +22,31 @@ import org.eclipse.jetty.io.SimpleBuffers;
import org.eclipse.jetty.io.nio.AsyncConnection;
import org.eclipse.jetty.io.nio.SelectChannelEndPoint;
import org.eclipse.jetty.io.nio.SelectorManager;
import org.eclipse.jetty.io.nio.SslConnection;
import org.eclipse.jetty.util.B64Code;
import org.eclipse.jetty.util.QuotedStringTokenizer;
import org.eclipse.jetty.util.component.AggregateLifeCycle;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.ThreadPool;
/* ------------------------------------------------------------ */
/**
* <p>WebSocketClientFactory contains the common components needed by multiple {@link WebSocketClient} instances
* (for example, a {@link ThreadPool}, a {@link SelectorManager NIO selector}, etc).</p>
* <p>WebSocketClients with different configurations should share the same factory to avoid to waste resources.</p>
* <p>If a ThreadPool or MaskGen is passed in the constructor, then it is not added with {@link AggregateLifeCycle#addBean(Object)},
* so it's lifecycle must be controlled externally.
* so it's lifecycle must be controlled externally.
*
* @see WebSocketClient
*/
public class WebSocketClientFactory extends AggregateLifeCycle
{
private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName());
private final static Random __random = new Random();
private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept");
private SslContextFactory _sslContextFactory = new SslContextFactory();
private final ThreadPool _threadPool;
private final WebSocketClientSelector _selector;
private MaskGen _maskGen;
@ -55,54 +58,67 @@ public class WebSocketClientFactory extends AggregateLifeCycle
*/
public WebSocketClientFactory()
{
_threadPool=new QueuedThreadPool();
addBean(_threadPool);
_buffers=new WebSocketBuffers(8*1024);
addBean(_buffers);
_maskGen=new RandomMaskGen();
addBean(_maskGen);
_selector=new WebSocketClientSelector();
addBean(_selector);
this(new QueuedThreadPool());
}
/* ------------------------------------------------------------ */
/**
* <p>Creates a WebSocketClientFactory with the given ThreadPool and the default configuration.</p>
*
* @param threadPool the ThreadPool instance to use
*/
public WebSocketClientFactory(ThreadPool threadPool)
{
_threadPool=threadPool;
addBean(threadPool);
_buffers=new WebSocketBuffers(8*1024);
addBean(_buffers);
_maskGen=new RandomMaskGen();
addBean(_maskGen);
_selector=new WebSocketClientSelector();
addBean(_selector);
this(threadPool, new RandomMaskGen());
}
/* ------------------------------------------------------------ */
/**
* <p>Creates a WebSocketClientFactory with the specified configuration.</p>
* <p>Creates a WebSocketClientFactory with the given ThreadPool and the given MaskGen.</p>
*
* @param threadPool the ThreadPool instance to use
* @param maskGen the mask generator to use
* @param maskGen the MaskGen instance to use
*/
public WebSocketClientFactory(ThreadPool threadPool, MaskGen maskGen)
{
this(threadPool, maskGen, 8192);
}
/* ------------------------------------------------------------ */
/**
* <p>Creates a WebSocketClientFactory with the specified configuration.</p>
*
* @param threadPool the ThreadPool instance to use
* @param maskGen the mask generator to use
* @param bufferSize the read buffer size
*/
public WebSocketClientFactory(ThreadPool threadPool,MaskGen maskGen,int bufferSize)
public WebSocketClientFactory(ThreadPool threadPool, MaskGen maskGen, int bufferSize)
{
_threadPool=threadPool;
_threadPool = threadPool;
addBean(threadPool);
_buffers=new WebSocketBuffers(bufferSize);
_buffers = new WebSocketBuffers(bufferSize);
addBean(_buffers);
_maskGen=maskGen;
_selector=new WebSocketClientSelector();
_maskGen = maskGen;
addBean(_maskGen);
_selector = new WebSocketClientSelector();
addBean(_selector);
addBean(_sslContextFactory);
}
/* ------------------------------------------------------------ */
/**
* @return the SslContextFactory used to configure SSL parameters
*/
public SslContextFactory getSslContextFactory()
{
return _sslContextFactory;
}
/* ------------------------------------------------------------ */
/**
* Get the selectorManager. Used to configure the manager.
*
* @return The {@link SelectorManager} instance.
*/
public SelectorManager getSelectorManager()
@ -111,8 +127,10 @@ public class WebSocketClientFactory extends AggregateLifeCycle
}
/* ------------------------------------------------------------ */
/** Get the ThreadPool.
/**
* Get the ThreadPool.
* Used to set/query the thread pool configuration.
*
* @return The {@link ThreadPool}
*/
public ThreadPool getThreadPool()
@ -139,9 +157,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
{
if (isRunning())
throw new IllegalStateException(getState());
if (removeBean(_maskGen))
addBean(maskGen);
_maskGen=maskGen;
removeBean(_maskGen);
_maskGen = maskGen;
addBean(maskGen);
}
/* ------------------------------------------------------------ */
@ -154,7 +172,7 @@ public class WebSocketClientFactory extends AggregateLifeCycle
if (isRunning())
throw new IllegalStateException(getState());
removeBean(_buffers);
_buffers=new WebSocketBuffers(bufferSize);
_buffers = new WebSocketBuffers(bufferSize);
addBean(_buffers);
}
@ -179,24 +197,28 @@ public class WebSocketClientFactory extends AggregateLifeCycle
return new WebSocketClient(this);
}
/* ------------------------------------------------------------ */
@Override
protected void doStart() throws Exception
protected SSLEngine newSslEngine(SocketChannel channel) throws IOException
{
super.doStart();
if (getThreadPool() instanceof LifeCycle && !((LifeCycle)getThreadPool()).isStarted())
((LifeCycle)getThreadPool()).start();
SSLEngine sslEngine;
if (channel != null)
{
String peerHost = channel.socket().getInetAddress().getHostAddress();
int peerPort = channel.socket().getPort();
sslEngine = _sslContextFactory.newSslEngine(peerHost, peerPort);
}
else
{
sslEngine = _sslContextFactory.newSslEngine();
}
sslEngine.setUseClientMode(true);
sslEngine.beginHandshake();
return sslEngine;
}
/* ------------------------------------------------------------ */
@Override
protected void doStop() throws Exception
{
super.doStop();
}
/* ------------------------------------------------------------ */
/** WebSocket Client Selector Manager
/**
* WebSocket Client Selector Manager
*/
class WebSocketClientSelector extends SelectorManager
{
@ -209,16 +231,33 @@ public class WebSocketClientFactory extends AggregateLifeCycle
@Override
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, final SelectionKey key) throws IOException
{
SelectChannelEndPoint endp= new SelectChannelEndPoint(channel,selectSet,key,channel.socket().getSoTimeout());
endp.setConnection(selectSet.getManager().newConnection(channel,endp, key.attachment()));
return endp;
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture)key.attachment();
int maxIdleTime = holder.getMaxIdleTime();
if (maxIdleTime < 0)
maxIdleTime = (int)getMaxIdleTime();
SelectChannelEndPoint result = new SelectChannelEndPoint(channel, selectSet, key, maxIdleTime);
AsyncEndPoint endPoint = result;
// Detect if it is SSL, and wrap the connection if so
if ("wss".equals(holder.getURI().getScheme()))
{
SSLEngine sslEngine = newSslEngine(channel);
SslConnection sslConnection = new SslConnection(sslEngine, endPoint);
endPoint.setConnection(sslConnection);
endPoint = sslConnection.getSslEndPoint();
}
AsyncConnection connection = selectSet.getManager().newConnection(channel, endPoint, holder);
endPoint.setConnection(connection);
return result;
}
@Override
public AsyncConnection newConnection(SocketChannel channel, AsyncEndPoint endpoint, Object attachment)
{
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture) attachment;
return new HandshakeConnection(endpoint,holder);
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture)attachment;
return new HandshakeConnection(endpoint, holder);
}
@Override
@ -230,7 +269,7 @@ public class WebSocketClientFactory extends AggregateLifeCycle
@Override
protected void endPointUpgraded(ConnectedEndPoint endpoint, Connection oldConnection)
{
LOG.debug("upgrade {} -> {}",oldConnection,endpoint.getConnection());
LOG.debug("upgrade {} -> {}", oldConnection, endpoint.getConnection());
}
@Override
@ -243,7 +282,7 @@ public class WebSocketClientFactory extends AggregateLifeCycle
protected void connectionFailed(SocketChannel channel, Throwable ex, Object attachment)
{
if (!(attachment instanceof WebSocketClient.WebSocketFuture))
super.connectionFailed(channel,ex,attachment);
super.connectionFailed(channel, ex, attachment);
else
{
__log.debug(ex);
@ -254,9 +293,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
}
}
/* ------------------------------------------------------------ */
/** Handshake Connection.
/**
* Handshake Connection.
* Handles the connection until the handshake succeeds or fails.
*/
class HandshakeConnection extends AbstractConnection implements AsyncConnection
@ -267,29 +306,27 @@ public class WebSocketClientFactory extends AggregateLifeCycle
private final HttpParser _parser;
private String _accept;
private String _error;
private boolean _handshaken;
public HandshakeConnection(AsyncEndPoint endpoint, WebSocketClient.WebSocketFuture future)
{
super(endpoint,System.currentTimeMillis());
_endp=endpoint;
_future=future;
super(endpoint, System.currentTimeMillis());
_endp = endpoint;
_future = future;
byte[] bytes=new byte[16];
__random.nextBytes(bytes);
_key=new String(B64Code.encode(bytes));
byte[] bytes = new byte[16];
new Random().nextBytes(bytes);
_key = new String(B64Code.encode(bytes));
Buffers buffers = new SimpleBuffers(_buffers.getBuffer(),null);
_parser=new HttpParser(buffers,_endp,
new HttpParser.EventHandler()
Buffers buffers = new SimpleBuffers(_buffers.getBuffer(), null);
_parser = new HttpParser(buffers, _endp, new HttpParser.EventHandler()
{
@Override
public void startResponse(Buffer version, int status, Buffer reason) throws IOException
{
if (status!=101)
if (status != 101)
{
_error="Bad response status "+status+" "+reason;
_error = "Bad response status " + status + " " + reason;
_endp.close();
}
}
@ -298,65 +335,64 @@ public class WebSocketClientFactory extends AggregateLifeCycle
public void parsedHeader(Buffer name, Buffer value) throws IOException
{
if (__ACCEPT.equals(name))
_accept=value.toString();
_accept = value.toString();
}
@Override
public void startRequest(Buffer method, Buffer url, Buffer version) throws IOException
{
if (_error==null)
_error="Bad response: "+method+" "+url+" "+version;
if (_error == null)
_error = "Bad response: " + method + " " + url + " " + version;
_endp.close();
}
@Override
public void content(Buffer ref) throws IOException
{
if (_error==null)
_error="Bad response. "+ref.length()+"B of content?";
if (_error == null)
_error = "Bad response. " + ref.length() + "B of content?";
_endp.close();
}
});
}
String path=_future.getURI().getPath();
if (path==null || path.length()==0)
{
path="/";
}
if(_future.getURI().getRawQuery() != null)
{
private void handshake()
{
String path = _future.getURI().getPath();
if (path == null || path.length() == 0)
path = "/";
if (_future.getURI().getRawQuery() != null)
path += "?" + _future.getURI().getRawQuery();
}
String origin = future.getOrigin();
String origin = _future.getOrigin();
StringBuilder request = new StringBuilder(512);
request
.append("GET ").append(path).append(" HTTP/1.1\r\n")
.append("Host: ").append(future.getURI().getHost()).append(":").append(_future.getURI().getPort()).append("\r\n")
.append("Upgrade: websocket\r\n")
.append("Connection: Upgrade\r\n")
.append("Sec-WebSocket-Key: ")
.append(_key).append("\r\n");
if(origin!=null)
request.append("GET ").append(path).append(" HTTP/1.1\r\n")
.append("Host: ").append(_future.getURI().getHost()).append(":")
.append(_future.getURI().getPort()).append("\r\n")
.append("Upgrade: websocket\r\n")
.append("Connection: Upgrade\r\n")
.append("Sec-WebSocket-Key: ")
.append(_key).append("\r\n");
if (origin != null)
request.append("Origin: ").append(origin).append("\r\n");
request.append("Sec-WebSocket-Version: ").append(WebSocketConnectionD13.VERSION).append("\r\n");
if (future.getProtocol()!=null)
request.append("Sec-WebSocket-Protocol: ").append(future.getProtocol()).append("\r\n");
if (_future.getProtocol() != null)
request.append("Sec-WebSocket-Protocol: ").append(_future.getProtocol()).append("\r\n");
if (future.getCookies()!=null && future.getCookies().size()>0)
Map<String, String> cookies = _future.getCookies();
if (cookies != null && cookies.size() > 0)
{
for (String cookie : future.getCookies().keySet())
request
.append("Cookie: ")
.append(QuotedStringTokenizer.quoteIfNeeded(cookie,HttpFields.__COOKIE_DELIM))
.append("=")
.append(QuotedStringTokenizer.quoteIfNeeded(future.getCookies().get(cookie),HttpFields.__COOKIE_DELIM))
.append("\r\n");
for (String cookie : cookies.keySet())
request.append("Cookie: ")
.append(QuotedStringTokenizer.quoteIfNeeded(cookie, HttpFields.__COOKIE_DELIM))
.append("=")
.append(QuotedStringTokenizer.quoteIfNeeded(cookies.get(cookie), HttpFields.__COOKIE_DELIM))
.append("\r\n");
}
request.append("\r\n");
@ -365,14 +401,18 @@ public class WebSocketClientFactory extends AggregateLifeCycle
try
{
Buffer handshake = new ByteArrayBuffer(request.toString(),false);
int len=handshake.length();
if (len!=_endp.flush(handshake))
Buffer handshake = new ByteArrayBuffer(request.toString(), false);
int len = handshake.length();
if (len != _endp.flush(handshake))
throw new IOException("incomplete");
}
catch(IOException e)
catch (IOException e)
{
future.handshakeFailed(e);
_future.handshakeFailed(e);
}
finally
{
_handshaken = true;
}
}
@ -380,6 +420,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
{
while (_endp.isOpen() && !_parser.isComplete())
{
if (!_handshaken)
handshake();
if (!_parser.parseAvailable())
{
if (_endp.isInputShutdown())
@ -387,25 +430,29 @@ public class WebSocketClientFactory extends AggregateLifeCycle
return this;
}
}
if (_error==null)
if (_error == null)
{
if (_accept==null)
_error="No Sec-WebSocket-Accept";
if (_accept == null)
{
_error = "No Sec-WebSocket-Accept";
}
else if (!WebSocketConnectionD13.hashKey(_key).equals(_accept))
_error="Bad Sec-WebSocket-Accept";
{
_error = "Bad Sec-WebSocket-Accept";
}
else
{
Buffer header=_parser.getHeaderBuffer();
MaskGen maskGen=_future.getMaskGen();
WebSocketConnectionD13 connection =
new WebSocketConnectionD13(_future.getWebSocket(),
_endp,
_buffers,System.currentTimeMillis(),
_future.getMaxIdleTime(),
_future.getProtocol(),
null,
WebSocketConnectionD13.VERSION,
maskGen);
Buffer header = _parser.getHeaderBuffer();
MaskGen maskGen = _future.getMaskGen();
WebSocketConnectionD13 connection =
new WebSocketConnectionD13(_future.getWebSocket(),
_endp,
_buffers, System.currentTimeMillis(),
_future.getMaxIdleTime(),
_future.getProtocol(),
null,
WebSocketConnectionD13.VERSION,
maskGen);
if (header.hasContent())
connection.fillBuffersFrom(header);
@ -438,15 +485,10 @@ public class WebSocketClientFactory extends AggregateLifeCycle
public void onClose()
{
if (_error!=null)
if (_error != null)
_future.handshakeFailed(new ProtocolException(_error));
else
_future.handshakeFailed(new EOFException());
}
public String toString()
{
return "HS"+super.toString();
}
}
}

View File

@ -0,0 +1,124 @@
package org.eclipse.jetty.websocket;
import java.io.IOException;
import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ssl.SslSelectChannelConnector;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
public class WebSocketOverSSLTest
{
private Server _server;
private int _port;
private WebSocket.Connection _connection;
private void startServer(final WebSocket webSocket) throws Exception
{
_server = new Server();
SslSelectChannelConnector connector = new SslSelectChannelConnector();
_server.addConnector(connector);
SslContextFactory cf = connector.getSslContextFactory();
cf.setKeyStorePath(MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath());
cf.setKeyStorePassword("storepwd");
cf.setKeyManagerPassword("keypwd");
_server.setHandler(new WebSocketHandler()
{
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol)
{
return webSocket;
}
});
_server.start();
_port = connector.getLocalPort();
}
private void startClient(final WebSocket webSocket) throws Exception
{
Assert.assertTrue(_server.isStarted());
WebSocketClientFactory factory = new WebSocketClientFactory(new QueuedThreadPool(), new ZeroMaskGen());
SslContextFactory cf = factory.getSslContextFactory();
cf.setKeyStorePath(MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath());
cf.setKeyStorePassword("storepwd");
cf.setKeyManagerPassword("keypwd");
factory.start();
WebSocketClient client = new WebSocketClient(factory);
_connection = client.open(new URI("wss://localhost:" + _port), webSocket).get(5, TimeUnit.SECONDS);
}
@After
public void destroy() throws Exception
{
if (_connection != null)
_connection.close();
if (_server != null)
{
_server.stop();
_server.join();
}
}
@Test
public void testWebSocketOverSSL() throws Exception
{
final String message = "message";
final CountDownLatch serverLatch = new CountDownLatch(1);
startServer(new WebSocket.OnTextMessage()
{
private Connection connection;
public void onOpen(Connection connection)
{
this.connection = connection;
}
public void onMessage(String data)
{
try
{
Assert.assertEquals(message, data);
connection.sendMessage(data);
serverLatch.countDown();
}
catch (IOException x)
{
x.printStackTrace();
}
}
public void onClose(int closeCode, String message)
{
}
});
final CountDownLatch clientLatch = new CountDownLatch(1);
startClient(new WebSocket.OnTextMessage()
{
public void onOpen(Connection connection)
{
}
public void onMessage(String data)
{
Assert.assertEquals(message, data);
clientLatch.countDown();
}
public void onClose(int closeCode, String message)
{
}
});
_connection.sendMessage(message);
Assert.assertTrue(serverLatch.await(5, TimeUnit.SECONDS));
Assert.assertTrue(clientLatch.await(5, TimeUnit.SECONDS));
}
}

Binary file not shown.