365750 - Support WebSocket over SSL, aka wss://
This is now implemented, using the new architecture of wrapping the connection with SslConnection. The only refactoring was to avoid that the HTTP handshake was sent from the HandshakeConnection constructor, because at that point the SSL wiring is not ready yet. Now the handshake is sent from handle(), guarded by a boolean variable to sent it once.
This commit is contained in:
parent
ba95a9ba3a
commit
0689e05e9b
|
@ -320,8 +320,6 @@ public class WebSocketClient
|
|||
String scheme=uri.getScheme();
|
||||
if (!("ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme)))
|
||||
throw new IllegalArgumentException("Bad WebSocket scheme '"+scheme+"'");
|
||||
if ("wss".equalsIgnoreCase(scheme))
|
||||
throw new IOException("wss not supported");
|
||||
|
||||
SocketChannel channel = SocketChannel.open();
|
||||
if (_bindAddress != null)
|
||||
|
|
|
@ -5,7 +5,9 @@ import java.io.IOException;
|
|||
import java.net.ProtocolException;
|
||||
import java.nio.channels.SelectionKey;
|
||||
import java.nio.channels.SocketChannel;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
|
||||
import org.eclipse.jetty.http.HttpFields;
|
||||
import org.eclipse.jetty.http.HttpParser;
|
||||
|
@ -20,15 +22,15 @@ import org.eclipse.jetty.io.SimpleBuffers;
|
|||
import org.eclipse.jetty.io.nio.AsyncConnection;
|
||||
import org.eclipse.jetty.io.nio.SelectChannelEndPoint;
|
||||
import org.eclipse.jetty.io.nio.SelectorManager;
|
||||
import org.eclipse.jetty.io.nio.SslConnection;
|
||||
import org.eclipse.jetty.util.B64Code;
|
||||
import org.eclipse.jetty.util.QuotedStringTokenizer;
|
||||
import org.eclipse.jetty.util.component.AggregateLifeCycle;
|
||||
import org.eclipse.jetty.util.component.LifeCycle;
|
||||
import org.eclipse.jetty.util.log.Logger;
|
||||
import org.eclipse.jetty.util.ssl.SslContextFactory;
|
||||
import org.eclipse.jetty.util.thread.QueuedThreadPool;
|
||||
import org.eclipse.jetty.util.thread.ThreadPool;
|
||||
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/**
|
||||
* <p>WebSocketClientFactory contains the common components needed by multiple {@link WebSocketClient} instances
|
||||
|
@ -36,14 +38,15 @@ import org.eclipse.jetty.util.thread.ThreadPool;
|
|||
* <p>WebSocketClients with different configurations should share the same factory to avoid to waste resources.</p>
|
||||
* <p>If a ThreadPool or MaskGen is passed in the constructor, then it is not added with {@link AggregateLifeCycle#addBean(Object)},
|
||||
* so it's lifecycle must be controlled externally.
|
||||
*
|
||||
* @see WebSocketClient
|
||||
*/
|
||||
public class WebSocketClientFactory extends AggregateLifeCycle
|
||||
{
|
||||
private final static Logger __log = org.eclipse.jetty.util.log.Log.getLogger(WebSocketClientFactory.class.getName());
|
||||
private final static Random __random = new Random();
|
||||
private final static ByteArrayBuffer __ACCEPT = new ByteArrayBuffer.CaseInsensitive("Sec-WebSocket-Accept");
|
||||
|
||||
private SslContextFactory _sslContextFactory = new SslContextFactory();
|
||||
private final ThreadPool _threadPool;
|
||||
private final WebSocketClientSelector _selector;
|
||||
private MaskGen _maskGen;
|
||||
|
@ -55,36 +58,37 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
*/
|
||||
public WebSocketClientFactory()
|
||||
{
|
||||
_threadPool=new QueuedThreadPool();
|
||||
addBean(_threadPool);
|
||||
_buffers=new WebSocketBuffers(8*1024);
|
||||
addBean(_buffers);
|
||||
_maskGen=new RandomMaskGen();
|
||||
addBean(_maskGen);
|
||||
_selector=new WebSocketClientSelector();
|
||||
addBean(_selector);
|
||||
this(new QueuedThreadPool());
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/**
|
||||
* <p>Creates a WebSocketClientFactory with the given ThreadPool and the default configuration.</p>
|
||||
*
|
||||
* @param threadPool the ThreadPool instance to use
|
||||
*/
|
||||
public WebSocketClientFactory(ThreadPool threadPool)
|
||||
{
|
||||
_threadPool=threadPool;
|
||||
addBean(threadPool);
|
||||
_buffers=new WebSocketBuffers(8*1024);
|
||||
addBean(_buffers);
|
||||
_maskGen=new RandomMaskGen();
|
||||
addBean(_maskGen);
|
||||
_selector=new WebSocketClientSelector();
|
||||
addBean(_selector);
|
||||
this(threadPool, new RandomMaskGen());
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/**
|
||||
* <p>Creates a WebSocketClientFactory with the given ThreadPool and the given MaskGen.</p>
|
||||
*
|
||||
* @param threadPool the ThreadPool instance to use
|
||||
* @param maskGen the MaskGen instance to use
|
||||
*/
|
||||
public WebSocketClientFactory(ThreadPool threadPool, MaskGen maskGen)
|
||||
{
|
||||
this(threadPool, maskGen, 8192);
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
|
||||
/**
|
||||
* <p>Creates a WebSocketClientFactory with the specified configuration.</p>
|
||||
*
|
||||
* @param threadPool the ThreadPool instance to use
|
||||
* @param maskGen the mask generator to use
|
||||
* @param bufferSize the read buffer size
|
||||
|
@ -96,13 +100,25 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
_buffers = new WebSocketBuffers(bufferSize);
|
||||
addBean(_buffers);
|
||||
_maskGen = maskGen;
|
||||
addBean(_maskGen);
|
||||
_selector = new WebSocketClientSelector();
|
||||
addBean(_selector);
|
||||
addBean(_sslContextFactory);
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/**
|
||||
* @return the SslContextFactory used to configure SSL parameters
|
||||
*/
|
||||
public SslContextFactory getSslContextFactory()
|
||||
{
|
||||
return _sslContextFactory;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/**
|
||||
* Get the selectorManager. Used to configure the manager.
|
||||
*
|
||||
* @return The {@link SelectorManager} instance.
|
||||
*/
|
||||
public SelectorManager getSelectorManager()
|
||||
|
@ -111,8 +127,10 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/** Get the ThreadPool.
|
||||
/**
|
||||
* Get the ThreadPool.
|
||||
* Used to set/query the thread pool configuration.
|
||||
*
|
||||
* @return The {@link ThreadPool}
|
||||
*/
|
||||
public ThreadPool getThreadPool()
|
||||
|
@ -139,9 +157,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
{
|
||||
if (isRunning())
|
||||
throw new IllegalStateException(getState());
|
||||
if (removeBean(_maskGen))
|
||||
addBean(maskGen);
|
||||
removeBean(_maskGen);
|
||||
_maskGen = maskGen;
|
||||
addBean(maskGen);
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
|
@ -179,24 +197,28 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
return new WebSocketClient(this);
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
@Override
|
||||
protected void doStart() throws Exception
|
||||
protected SSLEngine newSslEngine(SocketChannel channel) throws IOException
|
||||
{
|
||||
super.doStart();
|
||||
if (getThreadPool() instanceof LifeCycle && !((LifeCycle)getThreadPool()).isStarted())
|
||||
((LifeCycle)getThreadPool()).start();
|
||||
SSLEngine sslEngine;
|
||||
if (channel != null)
|
||||
{
|
||||
String peerHost = channel.socket().getInetAddress().getHostAddress();
|
||||
int peerPort = channel.socket().getPort();
|
||||
sslEngine = _sslContextFactory.newSslEngine(peerHost, peerPort);
|
||||
}
|
||||
else
|
||||
{
|
||||
sslEngine = _sslContextFactory.newSslEngine();
|
||||
}
|
||||
sslEngine.setUseClientMode(true);
|
||||
sslEngine.beginHandshake();
|
||||
|
||||
return sslEngine;
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
@Override
|
||||
protected void doStop() throws Exception
|
||||
{
|
||||
super.doStop();
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/** WebSocket Client Selector Manager
|
||||
/**
|
||||
* WebSocket Client Selector Manager
|
||||
*/
|
||||
class WebSocketClientSelector extends SelectorManager
|
||||
{
|
||||
|
@ -209,9 +231,26 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
@Override
|
||||
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, final SelectionKey key) throws IOException
|
||||
{
|
||||
SelectChannelEndPoint endp= new SelectChannelEndPoint(channel,selectSet,key,channel.socket().getSoTimeout());
|
||||
endp.setConnection(selectSet.getManager().newConnection(channel,endp, key.attachment()));
|
||||
return endp;
|
||||
WebSocketClient.WebSocketFuture holder = (WebSocketClient.WebSocketFuture)key.attachment();
|
||||
int maxIdleTime = holder.getMaxIdleTime();
|
||||
if (maxIdleTime < 0)
|
||||
maxIdleTime = (int)getMaxIdleTime();
|
||||
SelectChannelEndPoint result = new SelectChannelEndPoint(channel, selectSet, key, maxIdleTime);
|
||||
AsyncEndPoint endPoint = result;
|
||||
|
||||
// Detect if it is SSL, and wrap the connection if so
|
||||
if ("wss".equals(holder.getURI().getScheme()))
|
||||
{
|
||||
SSLEngine sslEngine = newSslEngine(channel);
|
||||
SslConnection sslConnection = new SslConnection(sslEngine, endPoint);
|
||||
endPoint.setConnection(sslConnection);
|
||||
endPoint = sslConnection.getSslEndPoint();
|
||||
}
|
||||
|
||||
AsyncConnection connection = selectSet.getManager().newConnection(channel, endPoint, holder);
|
||||
endPoint.setConnection(connection);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -254,9 +293,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------------------------------------------ */
|
||||
/** Handshake Connection.
|
||||
/**
|
||||
* Handshake Connection.
|
||||
* Handles the connection until the handshake succeeds or fails.
|
||||
*/
|
||||
class HandshakeConnection extends AbstractConnection implements AsyncConnection
|
||||
|
@ -267,6 +306,7 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
private final HttpParser _parser;
|
||||
private String _accept;
|
||||
private String _error;
|
||||
private boolean _handshaken;
|
||||
|
||||
public HandshakeConnection(AsyncEndPoint endpoint, WebSocketClient.WebSocketFuture future)
|
||||
{
|
||||
|
@ -275,14 +315,11 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
_future = future;
|
||||
|
||||
byte[] bytes = new byte[16];
|
||||
__random.nextBytes(bytes);
|
||||
new Random().nextBytes(bytes);
|
||||
_key = new String(B64Code.encode(bytes));
|
||||
|
||||
|
||||
Buffers buffers = new SimpleBuffers(_buffers.getBuffer(), null);
|
||||
_parser=new HttpParser(buffers,_endp,
|
||||
|
||||
new HttpParser.EventHandler()
|
||||
_parser = new HttpParser(buffers, _endp, new HttpParser.EventHandler()
|
||||
{
|
||||
@Override
|
||||
public void startResponse(Buffer version, int status, Buffer reason) throws IOException
|
||||
|
@ -317,24 +354,23 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
_endp.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void handshake()
|
||||
{
|
||||
String path = _future.getURI().getPath();
|
||||
if (path == null || path.length() == 0)
|
||||
{
|
||||
path = "/";
|
||||
}
|
||||
|
||||
if (_future.getURI().getRawQuery() != null)
|
||||
{
|
||||
path += "?" + _future.getURI().getRawQuery();
|
||||
}
|
||||
|
||||
String origin = future.getOrigin();
|
||||
String origin = _future.getOrigin();
|
||||
|
||||
StringBuilder request = new StringBuilder(512);
|
||||
request
|
||||
.append("GET ").append(path).append(" HTTP/1.1\r\n")
|
||||
.append("Host: ").append(future.getURI().getHost()).append(":").append(_future.getURI().getPort()).append("\r\n")
|
||||
request.append("GET ").append(path).append(" HTTP/1.1\r\n")
|
||||
.append("Host: ").append(_future.getURI().getHost()).append(":")
|
||||
.append(_future.getURI().getPort()).append("\r\n")
|
||||
.append("Upgrade: websocket\r\n")
|
||||
.append("Connection: Upgrade\r\n")
|
||||
.append("Sec-WebSocket-Key: ")
|
||||
|
@ -345,17 +381,17 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
|
||||
request.append("Sec-WebSocket-Version: ").append(WebSocketConnectionD13.VERSION).append("\r\n");
|
||||
|
||||
if (future.getProtocol()!=null)
|
||||
request.append("Sec-WebSocket-Protocol: ").append(future.getProtocol()).append("\r\n");
|
||||
if (_future.getProtocol() != null)
|
||||
request.append("Sec-WebSocket-Protocol: ").append(_future.getProtocol()).append("\r\n");
|
||||
|
||||
if (future.getCookies()!=null && future.getCookies().size()>0)
|
||||
Map<String, String> cookies = _future.getCookies();
|
||||
if (cookies != null && cookies.size() > 0)
|
||||
{
|
||||
for (String cookie : future.getCookies().keySet())
|
||||
request
|
||||
.append("Cookie: ")
|
||||
for (String cookie : cookies.keySet())
|
||||
request.append("Cookie: ")
|
||||
.append(QuotedStringTokenizer.quoteIfNeeded(cookie, HttpFields.__COOKIE_DELIM))
|
||||
.append("=")
|
||||
.append(QuotedStringTokenizer.quoteIfNeeded(future.getCookies().get(cookie),HttpFields.__COOKIE_DELIM))
|
||||
.append(QuotedStringTokenizer.quoteIfNeeded(cookies.get(cookie), HttpFields.__COOKIE_DELIM))
|
||||
.append("\r\n");
|
||||
}
|
||||
|
||||
|
@ -372,7 +408,11 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
}
|
||||
catch (IOException e)
|
||||
{
|
||||
future.handshakeFailed(e);
|
||||
_future.handshakeFailed(e);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_handshaken = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -380,6 +420,9 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
{
|
||||
while (_endp.isOpen() && !_parser.isComplete())
|
||||
{
|
||||
if (!_handshaken)
|
||||
handshake();
|
||||
|
||||
if (!_parser.parseAvailable())
|
||||
{
|
||||
if (_endp.isInputShutdown())
|
||||
|
@ -390,9 +433,13 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
if (_error == null)
|
||||
{
|
||||
if (_accept == null)
|
||||
{
|
||||
_error = "No Sec-WebSocket-Accept";
|
||||
}
|
||||
else if (!WebSocketConnectionD13.hashKey(_key).equals(_accept))
|
||||
{
|
||||
_error = "Bad Sec-WebSocket-Accept";
|
||||
}
|
||||
else
|
||||
{
|
||||
Buffer header = _parser.getHeaderBuffer();
|
||||
|
@ -443,10 +490,5 @@ public class WebSocketClientFactory extends AggregateLifeCycle
|
|||
else
|
||||
_future.handshakeFailed(new EOFException());
|
||||
}
|
||||
|
||||
public String toString()
|
||||
{
|
||||
return "HS"+super.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
package org.eclipse.jetty.websocket;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
||||
import org.eclipse.jetty.server.Server;
|
||||
import org.eclipse.jetty.server.ssl.SslSelectChannelConnector;
|
||||
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
|
||||
import org.eclipse.jetty.util.ssl.SslContextFactory;
|
||||
import org.eclipse.jetty.util.thread.QueuedThreadPool;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class WebSocketOverSSLTest
|
||||
{
|
||||
private Server _server;
|
||||
private int _port;
|
||||
private WebSocket.Connection _connection;
|
||||
|
||||
private void startServer(final WebSocket webSocket) throws Exception
|
||||
{
|
||||
_server = new Server();
|
||||
SslSelectChannelConnector connector = new SslSelectChannelConnector();
|
||||
_server.addConnector(connector);
|
||||
SslContextFactory cf = connector.getSslContextFactory();
|
||||
cf.setKeyStorePath(MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath());
|
||||
cf.setKeyStorePassword("storepwd");
|
||||
cf.setKeyManagerPassword("keypwd");
|
||||
_server.setHandler(new WebSocketHandler()
|
||||
{
|
||||
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol)
|
||||
{
|
||||
return webSocket;
|
||||
}
|
||||
});
|
||||
_server.start();
|
||||
_port = connector.getLocalPort();
|
||||
}
|
||||
|
||||
private void startClient(final WebSocket webSocket) throws Exception
|
||||
{
|
||||
Assert.assertTrue(_server.isStarted());
|
||||
|
||||
WebSocketClientFactory factory = new WebSocketClientFactory(new QueuedThreadPool(), new ZeroMaskGen());
|
||||
SslContextFactory cf = factory.getSslContextFactory();
|
||||
cf.setKeyStorePath(MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath());
|
||||
cf.setKeyStorePassword("storepwd");
|
||||
cf.setKeyManagerPassword("keypwd");
|
||||
factory.start();
|
||||
WebSocketClient client = new WebSocketClient(factory);
|
||||
_connection = client.open(new URI("wss://localhost:" + _port), webSocket).get(5, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@After
|
||||
public void destroy() throws Exception
|
||||
{
|
||||
if (_connection != null)
|
||||
_connection.close();
|
||||
if (_server != null)
|
||||
{
|
||||
_server.stop();
|
||||
_server.join();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWebSocketOverSSL() throws Exception
|
||||
{
|
||||
final String message = "message";
|
||||
final CountDownLatch serverLatch = new CountDownLatch(1);
|
||||
startServer(new WebSocket.OnTextMessage()
|
||||
{
|
||||
private Connection connection;
|
||||
|
||||
public void onOpen(Connection connection)
|
||||
{
|
||||
this.connection = connection;
|
||||
}
|
||||
|
||||
public void onMessage(String data)
|
||||
{
|
||||
try
|
||||
{
|
||||
Assert.assertEquals(message, data);
|
||||
connection.sendMessage(data);
|
||||
serverLatch.countDown();
|
||||
}
|
||||
catch (IOException x)
|
||||
{
|
||||
x.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public void onClose(int closeCode, String message)
|
||||
{
|
||||
}
|
||||
});
|
||||
final CountDownLatch clientLatch = new CountDownLatch(1);
|
||||
startClient(new WebSocket.OnTextMessage()
|
||||
{
|
||||
public void onOpen(Connection connection)
|
||||
{
|
||||
}
|
||||
|
||||
public void onMessage(String data)
|
||||
{
|
||||
Assert.assertEquals(message, data);
|
||||
clientLatch.countDown();
|
||||
}
|
||||
|
||||
public void onClose(int closeCode, String message)
|
||||
{
|
||||
}
|
||||
});
|
||||
_connection.sendMessage(message);
|
||||
|
||||
Assert.assertTrue(serverLatch.await(5, TimeUnit.SECONDS));
|
||||
Assert.assertTrue(clientLatch.await(5, TimeUnit.SECONDS));
|
||||
}
|
||||
}
|
Binary file not shown.
Loading…
Reference in New Issue