WSServlet for websockets will attempt to detect the subprotocol
requested and respond with the appropriate one.  Currently the protocols
loaded are what stomp.js use for stomp (v11.stomp and v12.stomp).
If a protocol can't be found then a default will be returned, either
"stomp" or "mqtt", which is the same behavior before this patch.

This will make it a bit easier to use stomp over websockets out of the
box as stomp.js will work by default.

(cherry picked from commit 913f64476b)
This commit is contained in:
Christopher L. Shannon (cshannon) 2015-12-07 15:29:13 +00:00
parent 11ebaccd64
commit a4845253d0
7 changed files with 335 additions and 8 deletions

View File

@ -55,6 +55,10 @@ public class StompWSConnection extends WebSocketAdapter implements WebSocketList
}
}
protected Session getConnection() {
return connection;
}
//---- Send methods ------------------------------------------------------//
public synchronized void sendRawFrame(String rawFrame) throws Exception {

View File

@ -18,6 +18,12 @@
package org.apache.activemq.transport.ws.jetty9;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
@ -42,6 +48,19 @@ public class WSServlet extends WebSocketServlet {
private TransportAcceptListener listener;
private final static Map<String, Integer> stompProtocols = new ConcurrentHashMap<> ();
private final static Map<String, Integer> mqttProtocols = new ConcurrentHashMap<> ();
static {
stompProtocols.put("v12.stomp", 3);
stompProtocols.put("v11.stomp", 2);
stompProtocols.put("v10.stomp", 1);
stompProtocols.put("stomp", 0);
mqttProtocols.put("mqttv3.1", 1);
mqttProtocols.put("mqtt", 0);
}
@Override
public void init() throws ServletException {
super.init();
@ -70,16 +89,51 @@ public class WSServlet extends WebSocketServlet {
}
if (isMqtt) {
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
resp.setAcceptedSubProtocol("mqtt");
resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols,req.getSubProtocols(), "mqtt"));
((MQTTSocket)socket).setPeerCertificates(req.getCertificates());
} else {
socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
((StompSocket)socket).setCertificates(req.getCertificates());
resp.setAcceptedSubProtocol("stomp");
resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols,req.getSubProtocols(), "stomp"));
}
listener.onAccept((Transport) socket);
return socket;
}
});
}
private String getAcceptedSubProtocol(final Map<String, Integer> protocols,
List<String> subProtocols, String defaultProtocol) {
List<SubProtocol> matchedProtocols = new ArrayList<>();
if (subProtocols != null && subProtocols.size() > 0) {
//detect which subprotocols match accepted protocols and add to the list
for (String subProtocol : subProtocols) {
Integer priority = protocols.get(subProtocol);
if(subProtocol != null && priority != null) {
//only insert if both subProtocol and priority are not null
matchedProtocols.add(new SubProtocol(subProtocol, priority));
}
}
//sort the list by priority
if (matchedProtocols.size() > 0) {
Collections.sort(matchedProtocols, new Comparator<SubProtocol>() {
@Override
public int compare(SubProtocol s1, SubProtocol s2) {
return s2.priority.compareTo(s1.priority);
}
});
return matchedProtocols.get(0).protocol;
}
}
return defaultProtocol;
}
private class SubProtocol {
private String protocol;
private Integer priority;
public SubProtocol(String protocol, Integer priority) {
this.protocol = protocol;
this.priority = priority;
}
}
}

View File

@ -74,6 +74,10 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
}
}
protected Session getConnection() {
return connection;
}
//----- Connection and Disconnection methods -----------------------------//
public void connect() throws Exception {

View File

@ -41,17 +41,13 @@ public class MQTTWSConnectionTimeoutTest extends WSTransportTestSupport {
super.setUp();
wsMQTTConnection = new MQTTWSConnection();
// WebSocketClientFactory clientFactory = new WebSocketClientFactory();
//clientFactory.start();
wsClient = new WebSocketClient();
wsClient.start();
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols("mqtt");
request.setSubProtocols("mqttv3.1");
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
//wsClient.setProtocol("mqttv3.1");
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to MQTT WS endpoint");

View File

@ -0,0 +1,93 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
public class MQTTWSSubProtocolTest extends WSTransportTestSupport {
protected WebSocketClient wsClient;
protected MQTTWSConnection wsMQTTConnection;
protected ClientUpgradeRequest request;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
wsClient = new WebSocketClient(new SslContextFactory(true));
wsClient.start();
}
@Override
@After
public void tearDown() throws Exception {
if (wsMQTTConnection != null) {
wsMQTTConnection.close();
wsMQTTConnection = null;
wsClient = null;
}
super.tearDown();
}
@Test(timeout = 60000)
public void testConnectv31() throws Exception {
connect("mqttv3.1");
wsMQTTConnection.connect();
assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
}
@Test(timeout = 60000)
public void testConnectMqtt() throws Exception {
connect("mqtt");
wsMQTTConnection.connect();
assertEquals("mqtt", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
}
@Test(timeout = 60000)
public void testConnectMultiple() throws Exception {
connect("mqtt,mqttv3.1");
wsMQTTConnection.connect();
assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
}
protected void connect(String subProtocol) throws Exception {
request = new ClientUpgradeRequest();
if (subProtocol != null) {
request.setSubProtocols(subProtocol);
}
wsMQTTConnection = new MQTTWSConnection();
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to MQTT WS endpoint");
}
}
}

View File

@ -0,0 +1,169 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import org.apache.activemq.transport.stomp.Stomp;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
/**
* Test STOMP sub protocol detection.
*/
public class StompWSSubProtocolTest extends WSTransportTestSupport {
protected WebSocketClient wsClient;
protected StompWSConnection wsStompConnection;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
wsStompConnection = new StompWSConnection();
}
@Override
@After
public void tearDown() throws Exception {
if (wsStompConnection != null) {
wsStompConnection.close();
wsStompConnection = null;
wsClient = null;
}
super.tearDown();
}
@Test(timeout = 60000)
public void testConnectV12() throws Exception {
connect("v12.stomp");
String connectFrame = "STOMP\n" +
"accept-version:1.2\n" +
"host:localhost\n" +
"\n" + Stomp.NULL;
wsStompConnection.sendRawFrame(connectFrame);
assertSubProtocol("v12.stomp");
}
@Test(timeout = 60000)
public void testConnectV11() throws Exception {
connect("v11.stomp");
String connectFrame = "STOMP\n" +
"accept-version:1.2\n" +
"host:localhost\n" +
"\n" + Stomp.NULL;
wsStompConnection.sendRawFrame(connectFrame);
assertSubProtocol("v11.stomp");
}
@Test(timeout = 60000)
public void testConnectV10() throws Exception {
connect("v10.stomp");
String connectFrame = "STOMP\n" +
"accept-version:1.2\n" +
"host:localhost\n" +
"\n" + Stomp.NULL;
wsStompConnection.sendRawFrame(connectFrame);
assertSubProtocol("v10.stomp");
}
@Test(timeout = 60000)
public void testConnectNone() throws Exception {
connect(null);
String connectFrame = "STOMP\n" +
"accept-version:1.2\n" +
"host:localhost\n" +
"\n" + Stomp.NULL;
wsStompConnection.sendRawFrame(connectFrame);
assertSubProtocol("stomp");
}
@Test(timeout = 60000)
public void testConnectMultiple() throws Exception {
connect("v10.stomp,v11.stomp");
String connectFrame = "STOMP\n" +
"accept-version:1.2\n" +
"host:localhost\n" +
"\n" + Stomp.NULL;
wsStompConnection.sendRawFrame(connectFrame);
assertSubProtocol("v11.stomp");
}
@Test(timeout = 60000)
public void testConnectInvalid() throws Exception {
connect("invalid");
String connectFrame = "STOMP\n" +
"accept-version:1.2\n" +
"host:localhost\n" +
"\n" + Stomp.NULL;
wsStompConnection.sendRawFrame(connectFrame);
assertSubProtocol("stomp");
}
protected void connect(String subProtocol) throws Exception{
ClientUpgradeRequest request = new ClientUpgradeRequest();
if (subProtocol != null) {
request.setSubProtocols(subProtocol);
}
wsClient = new WebSocketClient(new SslContextFactory(true));
wsClient.start();
wsClient.connect(wsStompConnection, wsConnectUri, request);
if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to STOMP WS endpoint");
}
}
protected void assertSubProtocol(String subProtocol) throws Exception {
String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
assertNotNull(incoming);
assertTrue(incoming.startsWith("CONNECTED"));
assertEquals(subProtocol, wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
}
}

View File

@ -16,6 +16,7 @@
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
@ -28,6 +29,7 @@ import org.apache.activemq.transport.stomp.Stomp;
import org.apache.activemq.transport.stomp.StompFrame;
import org.apache.activemq.util.Wait;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.After;
import org.junit.Before;
@ -51,10 +53,14 @@ public class StompWSTransportTest extends WSTransportTestSupport {
super.setUp();
wsStompConnection = new StompWSConnection();
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols("v11.stomp");
wsClient = new WebSocketClient(new SslContextFactory(true));
wsClient.start();
wsClient.connect(wsStompConnection, wsConnectUri);
wsClient.connect(wsStompConnection, wsConnectUri, request);
if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to STOMP WS endpoint");
}
@ -86,6 +92,7 @@ public class StompWSTransportTest extends WSTransportTestSupport {
String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
assertNotNull(incoming);
assertTrue(incoming.startsWith("CONNECTED"));
assertEquals("v11.stomp", wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {