365750 - Support WebSocket over SSL, aka wss://

This is now implemented, using the new architecture of wrapping the connection with
SslConnection.
The only refactoring was to avoid that the HTTP handshake was sent from the
HandshakeConnection constructor, because at that point the SSL wiring is not ready yet.
Now the handshake is sent from handle(), guarded by a boolean variable to sent it once.
This commit is contained in:
Simone Bordet 2011-12-06 16:25:15 +01:00
parent ba95a9ba3a
commit 0689e05e9b
4 changed files with 296 additions and 132 deletions

View File

@ -320,8 +320,6 @@ public class WebSocketClient
String scheme=uri.getScheme();
if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme)))
throw new IllegalArgumentException("Bad WebSocket scheme '"+scheme+"'");
if ("wss".equalsIgnoreCase(scheme))
throw new IOException("wss not supported");
SocketChannel channel = SocketChannel.open();
if (_bindAddress != null)

View File

@ -5,7 +5,9 @@ import java.io.IOException;
import java.net.ProtocolException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.Map;
import java.util.Random;
import javax.net.ssl.SSLEngine;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpParser;
@ -20,15 +22,15 @@ import org.eclipse.jetty.io.SimpleBuffers;
import org.eclipse.jetty.io.nio.AsyncConnection;
import org.eclipse.jetty.io.nio.SelectChannelEndPoint;
import org.eclipse.jetty.io.nio.SelectorManager;
import org.eclipse.jetty.io.nio.SslConnection;
import org.eclipse.jetty.util.B64Code;
import org.eclipse.jetty.util.QuotedStringTokenizer;
import org.eclipse.jetty.util.component.AggregateLifeCycle;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.ThreadPool;
/* ------------------------------------------------------------ */
/**
* <p>WebSocketClientFactory contains the common components needed by multiple {@link WebSocketClient} instances
@ -36,14 +38,15 @@ import org.eclipse.jetty.util.thread.ThreadPool;
* <p>WebSocketClients with different configurations should share the same factory to avoid to waste resources.</p>
* <p>If a ThreadPool or MaskGen is passed in the constructor, then it is not added with {@link AggregateLifeCycle#addBean(Object)},
* so it's lifecycle must be controlled externally.
*
* @see WebSocketClient
*/
public class WebSocketClientFactory extends AggregateLifeCycle
{
private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName());
private final static Random __random = new Random();
private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept");
private SslContextFactory _sslContextFactory = new SslContextFactory();
private final ThreadPool _threadPool;
private final WebSocketClientSelector _selector;
private MaskGen _maskGen;
@ -55,36 +58,37 @@ public class WebSocketClientFactory extends AggregateLifeCycle
*/
public WebSocketClientFactory()
{
_threadPool=new QueuedThreadPool();
addBean(_threadPool);
_buffers=new WebSocketBuffers(8*1024);
addBean(_buffers);
_maskGen=new RandomMaskGen();
addBean(_maskGen);
_selector=new WebSocketClientSelector();
addBean(_selector);
this(new QueuedThreadPool());
}
/* ------------------------------------------------------------ */
/**
* <p>Creates a WebSocketClientFactory with the given ThreadPool and the default configuration.</p>
*
* @param threadPool the ThreadPool instance to use
*/
public WebSocketClientFactory(ThreadPool threadPool)
{
_threadPool=threadPool;
addBean(threadPool);
_buffers=new WebSocketBuffers(8*1024);
addBean(_buffers);
_maskGen=new RandomMaskGen();
addBean(_maskGen);
_selector=new WebSocketClientSelector();
addBean(_selector);
this(threadPool, new RandomMaskGen());
}
/* ------------------------------------------------------------ */
/**
* <p>Creates a WebSocketClientFactory with the given ThreadPool and the given MaskGen.</p>
*
* @param threadPool the ThreadPool instance to use
* @param maskGen the MaskGen instance to use
*/
public WebSocketClientFactory(ThreadPool threadPool, MaskGen maskGen)
{
this(threadPool, maskGen, 8192);
}
/* ------------------------------------------------------------ */
/**
* <p>Creates a WebSocketClientFactory with the specified configuration.</p>
*
* @param threadPool the ThreadPool instance to use
* @param maskGen the mask generator to use
* @param bufferSize the read buffer size
@ -96,13 +100,25 @@ public class WebSocketClientFactory extends AggregateLifeCycle
_buffers = new WebSocketBuffers(bufferSize);
addBean(_buffers);
_maskGen = maskGen;
addBean(_maskGen);
_selector = new WebSocketClientSelector();
addBean(_selector);
addBean(_sslContextFactory);
}
/* ------------------------------------------------------------ */
/**
* @return the SslContextFactory used to configure SSL parameters
*/
public SslContextFactory getSslContextFactory()
{
return _sslContextFactory;
}
/* ------------------------------------------------------------ */
/**
* Get the selectorManager. Used to configure the manager.
*
* @return The {@link SelectorManager} instance.
*/
public SelectorManager getSelectorManager()
@ -111,8 +127,10 @@ public class WebSocketClientFactory extends AggregateLifeCycle
}
/* ------------------------------------------------------------ */
/** Get the ThreadPool.
/**
* Get the ThreadPool.
* Used to set/query the thread pool configuration.
*
* @return The {@link ThreadPool}
*/
public ThreadPool getThreadPool()
@ -139,9 +157,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
{
if (isRunning())
throw new IllegalStateException(getState());
if (removeBean(_maskGen))
addBean(maskGen);
removeBean(_maskGen);
_maskGen = maskGen;
addBean(maskGen);
}
/* ------------------------------------------------------------ */
@ -179,24 +197,28 @@ public class WebSocketClientFactory extends AggregateLifeCycle
return new WebSocketClient(this);
}
/* ------------------------------------------------------------ */
@Override
protected void doStart() throws Exception
protected SSLEngine newSslEngine(SocketChannel channel) throws IOException
{
super.doStart();
if (getThreadPool() instanceof LifeCycle && !((LifeCycle)getThreadPool()).isStarted())
((LifeCycle)getThreadPool()).start();
SSLEngine sslEngine;
if (channel != null)
{
String peerHost = channel.socket().getInetAddress().getHostAddress();
int peerPort = channel.socket().getPort();
sslEngine = _sslContextFactory.newSslEngine(peerHost, peerPort);
}
else
{
sslEngine = _sslContextFactory.newSslEngine();
}
sslEngine.setUseClientMode(true);
sslEngine.beginHandshake();
return sslEngine;
}
/* ------------------------------------------------------------ */
@Override
protected void doStop() throws Exception
{
super.doStop();
}
/* ------------------------------------------------------------ */
/** WebSocket Client Selector Manager
/**
* WebSocket Client Selector Manager
*/
class WebSocketClientSelector extends SelectorManager
{
@ -209,9 +231,26 @@ public class WebSocketClientFactory extends AggregateLifeCycle
@Override
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, final SelectionKey key) throws IOException
{
SelectChannelEndPoint endp= new SelectChannelEndPoint(channel,selectSet,key,channel.socket().getSoTimeout());
endp.setConnection(selectSet.getManager().newConnection(channel,endp, key.attachment()));
return endp;
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture)key.attachment();
int maxIdleTime = holder.getMaxIdleTime();
if (maxIdleTime < 0)
maxIdleTime = (int)getMaxIdleTime();
SelectChannelEndPoint result = new SelectChannelEndPoint(channel, selectSet, key, maxIdleTime);
AsyncEndPoint endPoint = result;
// Detect if it is SSL, and wrap the connection if so
if ("wss".equals(holder.getURI().getScheme()))
{
SSLEngine sslEngine = newSslEngine(channel);
SslConnection sslConnection = new SslConnection(sslEngine, endPoint);
endPoint.setConnection(sslConnection);
endPoint = sslConnection.getSslEndPoint();
}
AsyncConnection connection = selectSet.getManager().newConnection(channel, endPoint, holder);
endPoint.setConnection(connection);
return result;
}
@Override
@ -254,9 +293,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
}
}
/* ------------------------------------------------------------ */
/** Handshake Connection.
/**
* Handshake Connection.
* Handles the connection until the handshake succeeds or fails.
*/
class HandshakeConnection extends AbstractConnection implements AsyncConnection
@ -267,6 +306,7 @@ public class WebSocketClientFactory extends AggregateLifeCycle
private final HttpParser _parser;
private String _accept;
private String _error;
private boolean _handshaken;
public HandshakeConnection(AsyncEndPoint endpoint, WebSocketClient.WebSocketFuture future)
{
@ -275,14 +315,11 @@ public class WebSocketClientFactory extends AggregateLifeCycle
_future = future;
byte[] bytes = new byte[16];
__random.nextBytes(bytes);
new Random().nextBytes(bytes);
_key = new String(B64Code.encode(bytes));
Buffers buffers = new SimpleBuffers(_buffers.getBuffer(), null);
_parser=new HttpParser(buffers,_endp,
new HttpParser.EventHandler()
_parser = new HttpParser(buffers, _endp, new HttpParser.EventHandler()
{
@Override
public void startResponse(Buffer version, int status, Buffer reason) throws IOException
@ -317,24 +354,23 @@ public class WebSocketClientFactory extends AggregateLifeCycle
_endp.close();
}
});
}
private void handshake()
{
String path = _future.getURI().getPath();
if (path == null || path.length() == 0)
{
path = "/";
}
if (_future.getURI().getRawQuery() != null)
{
path += "?" + _future.getURI().getRawQuery();
}
String origin = future.getOrigin();
String origin = _future.getOrigin();
StringBuilder request = new StringBuilder(512);
request
.append("GET ").append(path).append(" HTTP/1.1\r\n")
.append("Host: ").append(future.getURI().getHost()).append(":").append(_future.getURI().getPort()).append("\r\n")
request.append("GET ").append(path).append(" HTTP/1.1\r\n")
.append("Host: ").append(_future.getURI().getHost()).append(":")
.append(_future.getURI().getPort()).append("\r\n")
.append("Upgrade: websocket\r\n")
.append("Connection: Upgrade\r\n")
.append("Sec-WebSocket-Key: ")
@ -345,17 +381,17 @@ public class WebSocketClientFactory extends AggregateLifeCycle
request.append("Sec-WebSocket-Version: ").append(WebSocketConnectionD13.VERSION).append("\r\n");
if (future.getProtocol()!=null)
request.append("Sec-WebSocket-Protocol: ").append(future.getProtocol()).append("\r\n");
if (_future.getProtocol() != null)
request.append("Sec-WebSocket-Protocol: ").append(_future.getProtocol()).append("\r\n");
if (future.getCookies()!=null && future.getCookies().size()>0)
Map<String, String> cookies = _future.getCookies();
if (cookies != null && cookies.size() > 0)
{
for (String cookie : future.getCookies().keySet())
request
.append("Cookie: ")
for (String cookie : cookies.keySet())
request.append("Cookie: ")
.append(QuotedStringTokenizer.quoteIfNeeded(cookie, HttpFields.__COOKIE_DELIM))
.append("=")
.append(QuotedStringTokenizer.quoteIfNeeded(future.getCookies().get(cookie),HttpFields.__COOKIE_DELIM))
.append(QuotedStringTokenizer.quoteIfNeeded(cookies.get(cookie), HttpFields.__COOKIE_DELIM))
.append("\r\n");
}
@ -372,7 +408,11 @@ public class WebSocketClientFactory extends AggregateLifeCycle
}
catch (IOException e)
{
future.handshakeFailed(e);
_future.handshakeFailed(e);
}
finally
{
_handshaken = true;
}
}
@ -380,6 +420,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
{
while (_endp.isOpen() && !_parser.isComplete())
{
if (!_handshaken)
handshake();
if (!_parser.parseAvailable())
{
if (_endp.isInputShutdown())
@ -390,9 +433,13 @@ public class WebSocketClientFactory extends AggregateLifeCycle
if (_error == null)
{
if (_accept == null)
{
_error = "No Sec-WebSocket-Accept";
}
else if (!WebSocketConnectionD13.hashKey(_key).equals(_accept))
{
_error = "Bad Sec-WebSocket-Accept";
}
else
{
Buffer header = _parser.getHeaderBuffer();
@ -443,10 +490,5 @@ public class WebSocketClientFactory extends AggregateLifeCycle
else
_future.handshakeFailed(new EOFException());
}
public String toString()
{
return "HS"+super.toString();
}
}
}

View File

@ -0,0 +1,124 @@
package org.eclipse.jetty.websocket;
import java.io.IOException;
import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ssl.SslSelectChannelConnector;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
public class WebSocketOverSSLTest
{
private Server _server;
private int _port;
private WebSocket.Connection _connection;
private void startServer(final WebSocket webSocket) throws Exception
{
_server = new Server();
SslSelectChannelConnector connector = new SslSelectChannelConnector();
_server.addConnector(connector);
SslContextFactory cf = connector.getSslContextFactory();
cf.setKeyStorePath(MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath());
cf.setKeyStorePassword("storepwd");
cf.setKeyManagerPassword("keypwd");
_server.setHandler(new WebSocketHandler()
{
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol)
{
return webSocket;
}
});
_server.start();
_port = connector.getLocalPort();
}
private void startClient(final WebSocket webSocket) throws Exception
{
Assert.assertTrue(_server.isStarted());
WebSocketClientFactory factory = new WebSocketClientFactory(new QueuedThreadPool(), new ZeroMaskGen());
SslContextFactory cf = factory.getSslContextFactory();
cf.setKeyStorePath(MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath());
cf.setKeyStorePassword("storepwd");
cf.setKeyManagerPassword("keypwd");
factory.start();
WebSocketClient client = new WebSocketClient(factory);
_connection = client.open(new URI("wss://localhost:" + _port), webSocket).get(5, TimeUnit.SECONDS);
}
@After
public void destroy() throws Exception
{
if (_connection != null)
_connection.close();
if (_server != null)
{
_server.stop();
_server.join();
}
}
@Test
public void testWebSocketOverSSL() throws Exception
{
final String message = "message";
final CountDownLatch serverLatch = new CountDownLatch(1);
startServer(new WebSocket.OnTextMessage()
{
private Connection connection;
public void onOpen(Connection connection)
{
this.connection = connection;
}
public void onMessage(String data)
{
try
{
Assert.assertEquals(message, data);
connection.sendMessage(data);
serverLatch.countDown();
}
catch (IOException x)
{
x.printStackTrace();
}
}
public void onClose(int closeCode, String message)
{
}
});
final CountDownLatch clientLatch = new CountDownLatch(1);
startClient(new WebSocket.OnTextMessage()
{
public void onOpen(Connection connection)
{
}
public void onMessage(String data)
{
Assert.assertEquals(message, data);
clientLatch.countDown();
}
public void onClose(int closeCode, String message)
{
}
});
_connection.sendMessage(message);
Assert.assertTrue(serverLatch.await(5, TimeUnit.SECONDS));
Assert.assertTrue(clientLatch.await(5, TimeUnit.SECONDS));
}
}

Binary file not shown.