mirror of https://github.com/apache/activemq.git
WSServlet for websockets will attempt to detect the subprotocol requested and respond with the appropriate one. Currently the protocols loaded are what stomp.js use for stomp (v11.stomp and v12.stomp). If a protocol can't be found then a default will be returned, either "stomp" or "mqtt", which is the same behavior before this patch. This will make it a bit easier to use stomp over websockets out of the box as stomp.js will work by default.
This commit is contained in:
parent
66c348b1b8
commit
913f64476b
|
@ -55,6 +55,10 @@ public class StompWSConnection extends WebSocketAdapter implements WebSocketList
|
|||
}
|
||||
}
|
||||
|
||||
protected Session getConnection() {
|
||||
return connection;
|
||||
}
|
||||
|
||||
//---- Send methods ------------------------------------------------------//
|
||||
|
||||
public synchronized void sendRawFrame(String rawFrame) throws Exception {
|
||||
|
|
|
@ -18,6 +18,12 @@
|
|||
package org.apache.activemq.transport.ws.jetty9;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
|
@ -42,6 +48,19 @@ public class WSServlet extends WebSocketServlet {
|
|||
|
||||
private TransportAcceptListener listener;
|
||||
|
||||
private final static Map<String, Integer> stompProtocols = new ConcurrentHashMap<> ();
|
||||
private final static Map<String, Integer> mqttProtocols = new ConcurrentHashMap<> ();
|
||||
|
||||
static {
|
||||
stompProtocols.put("v12.stomp", 3);
|
||||
stompProtocols.put("v11.stomp", 2);
|
||||
stompProtocols.put("v10.stomp", 1);
|
||||
stompProtocols.put("stomp", 0);
|
||||
|
||||
mqttProtocols.put("mqttv3.1", 1);
|
||||
mqttProtocols.put("mqtt", 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init() throws ServletException {
|
||||
super.init();
|
||||
|
@ -70,16 +89,51 @@ public class WSServlet extends WebSocketServlet {
|
|||
}
|
||||
if (isMqtt) {
|
||||
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
|
||||
resp.setAcceptedSubProtocol("mqtt");
|
||||
resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols,req.getSubProtocols(), "mqtt"));
|
||||
((MQTTSocket)socket).setPeerCertificates(req.getCertificates());
|
||||
} else {
|
||||
socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
|
||||
((StompSocket)socket).setCertificates(req.getCertificates());
|
||||
resp.setAcceptedSubProtocol("stomp");
|
||||
resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols,req.getSubProtocols(), "stomp"));
|
||||
}
|
||||
listener.onAccept((Transport) socket);
|
||||
return socket;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private String getAcceptedSubProtocol(final Map<String, Integer> protocols,
|
||||
List<String> subProtocols, String defaultProtocol) {
|
||||
List<SubProtocol> matchedProtocols = new ArrayList<>();
|
||||
if (subProtocols != null && subProtocols.size() > 0) {
|
||||
//detect which subprotocols match accepted protocols and add to the list
|
||||
for (String subProtocol : subProtocols) {
|
||||
Integer priority = protocols.get(subProtocol);
|
||||
if(subProtocol != null && priority != null) {
|
||||
//only insert if both subProtocol and priority are not null
|
||||
matchedProtocols.add(new SubProtocol(subProtocol, priority));
|
||||
}
|
||||
}
|
||||
//sort the list by priority
|
||||
if (matchedProtocols.size() > 0) {
|
||||
Collections.sort(matchedProtocols, new Comparator<SubProtocol>() {
|
||||
@Override
|
||||
public int compare(SubProtocol s1, SubProtocol s2) {
|
||||
return s2.priority.compareTo(s1.priority);
|
||||
}
|
||||
});
|
||||
return matchedProtocols.get(0).protocol;
|
||||
}
|
||||
}
|
||||
return defaultProtocol;
|
||||
}
|
||||
|
||||
private class SubProtocol {
|
||||
private String protocol;
|
||||
private Integer priority;
|
||||
public SubProtocol(String protocol, Integer priority) {
|
||||
this.protocol = protocol;
|
||||
this.priority = priority;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -74,6 +74,10 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
|
|||
}
|
||||
}
|
||||
|
||||
protected Session getConnection() {
|
||||
return connection;
|
||||
}
|
||||
|
||||
//----- Connection and Disconnection methods -----------------------------//
|
||||
|
||||
public void connect() throws Exception {
|
||||
|
|
|
@ -41,17 +41,13 @@ public class MQTTWSConnectionTimeoutTest extends WSTransportTestSupport {
|
|||
super.setUp();
|
||||
wsMQTTConnection = new MQTTWSConnection();
|
||||
|
||||
// WebSocketClientFactory clientFactory = new WebSocketClientFactory();
|
||||
//clientFactory.start();
|
||||
|
||||
wsClient = new WebSocketClient();
|
||||
wsClient.start();
|
||||
|
||||
ClientUpgradeRequest request = new ClientUpgradeRequest();
|
||||
request.setSubProtocols("mqtt");
|
||||
request.setSubProtocols("mqttv3.1");
|
||||
|
||||
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
|
||||
//wsClient.setProtocol("mqttv3.1");
|
||||
|
||||
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
|
||||
throw new IOException("Could not connect to MQTT WS endpoint");
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.activemq.transport.ws;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import org.eclipse.jetty.util.ssl.SslContextFactory;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
public class MQTTWSSubProtocolTest extends WSTransportTestSupport {
|
||||
|
||||
protected WebSocketClient wsClient;
|
||||
protected MQTTWSConnection wsMQTTConnection;
|
||||
protected ClientUpgradeRequest request;
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
|
||||
wsClient = new WebSocketClient(new SslContextFactory(true));
|
||||
wsClient.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
if (wsMQTTConnection != null) {
|
||||
wsMQTTConnection.close();
|
||||
wsMQTTConnection = null;
|
||||
wsClient = null;
|
||||
}
|
||||
|
||||
super.tearDown();
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectv31() throws Exception {
|
||||
connect("mqttv3.1");
|
||||
wsMQTTConnection.connect();
|
||||
assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectMqtt() throws Exception {
|
||||
connect("mqtt");
|
||||
wsMQTTConnection.connect();
|
||||
assertEquals("mqtt", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectMultiple() throws Exception {
|
||||
connect("mqtt,mqttv3.1");
|
||||
wsMQTTConnection.connect();
|
||||
assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
|
||||
}
|
||||
|
||||
protected void connect(String subProtocol) throws Exception {
|
||||
request = new ClientUpgradeRequest();
|
||||
if (subProtocol != null) {
|
||||
request.setSubProtocols(subProtocol);
|
||||
}
|
||||
|
||||
wsMQTTConnection = new MQTTWSConnection();
|
||||
|
||||
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
|
||||
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
|
||||
throw new IOException("Could not connect to MQTT WS endpoint");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.activemq.transport.ws;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import org.apache.activemq.transport.stomp.Stomp;
|
||||
import org.eclipse.jetty.util.ssl.SslContextFactory;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Test STOMP sub protocol detection.
|
||||
*/
|
||||
public class StompWSSubProtocolTest extends WSTransportTestSupport {
|
||||
|
||||
protected WebSocketClient wsClient;
|
||||
protected StompWSConnection wsStompConnection;
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
wsStompConnection = new StompWSConnection();
|
||||
}
|
||||
|
||||
@Override
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
if (wsStompConnection != null) {
|
||||
wsStompConnection.close();
|
||||
wsStompConnection = null;
|
||||
wsClient = null;
|
||||
}
|
||||
|
||||
super.tearDown();
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectV12() throws Exception {
|
||||
connect("v12.stomp");
|
||||
|
||||
String connectFrame = "STOMP\n" +
|
||||
"accept-version:1.2\n" +
|
||||
"host:localhost\n" +
|
||||
"\n" + Stomp.NULL;
|
||||
|
||||
wsStompConnection.sendRawFrame(connectFrame);
|
||||
|
||||
assertSubProtocol("v12.stomp");
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectV11() throws Exception {
|
||||
connect("v11.stomp");
|
||||
|
||||
String connectFrame = "STOMP\n" +
|
||||
"accept-version:1.2\n" +
|
||||
"host:localhost\n" +
|
||||
"\n" + Stomp.NULL;
|
||||
|
||||
wsStompConnection.sendRawFrame(connectFrame);
|
||||
|
||||
assertSubProtocol("v11.stomp");
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectV10() throws Exception {
|
||||
connect("v10.stomp");
|
||||
|
||||
String connectFrame = "STOMP\n" +
|
||||
"accept-version:1.2\n" +
|
||||
"host:localhost\n" +
|
||||
"\n" + Stomp.NULL;
|
||||
|
||||
wsStompConnection.sendRawFrame(connectFrame);
|
||||
|
||||
assertSubProtocol("v10.stomp");
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectNone() throws Exception {
|
||||
|
||||
connect(null);
|
||||
|
||||
String connectFrame = "STOMP\n" +
|
||||
"accept-version:1.2\n" +
|
||||
"host:localhost\n" +
|
||||
"\n" + Stomp.NULL;
|
||||
|
||||
wsStompConnection.sendRawFrame(connectFrame);
|
||||
|
||||
assertSubProtocol("stomp");
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectMultiple() throws Exception {
|
||||
|
||||
connect("v10.stomp,v11.stomp");
|
||||
|
||||
String connectFrame = "STOMP\n" +
|
||||
"accept-version:1.2\n" +
|
||||
"host:localhost\n" +
|
||||
"\n" + Stomp.NULL;
|
||||
|
||||
wsStompConnection.sendRawFrame(connectFrame);
|
||||
|
||||
assertSubProtocol("v11.stomp");
|
||||
}
|
||||
|
||||
@Test(timeout = 60000)
|
||||
public void testConnectInvalid() throws Exception {
|
||||
connect("invalid");
|
||||
|
||||
String connectFrame = "STOMP\n" +
|
||||
"accept-version:1.2\n" +
|
||||
"host:localhost\n" +
|
||||
"\n" + Stomp.NULL;
|
||||
|
||||
wsStompConnection.sendRawFrame(connectFrame);
|
||||
|
||||
assertSubProtocol("stomp");
|
||||
}
|
||||
|
||||
protected void connect(String subProtocol) throws Exception{
|
||||
ClientUpgradeRequest request = new ClientUpgradeRequest();
|
||||
if (subProtocol != null) {
|
||||
request.setSubProtocols(subProtocol);
|
||||
}
|
||||
|
||||
wsClient = new WebSocketClient(new SslContextFactory(true));
|
||||
wsClient.start();
|
||||
|
||||
wsClient.connect(wsStompConnection, wsConnectUri, request);
|
||||
if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) {
|
||||
throw new IOException("Could not connect to STOMP WS endpoint");
|
||||
}
|
||||
}
|
||||
|
||||
protected void assertSubProtocol(String subProtocol) throws Exception {
|
||||
String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
|
||||
assertNotNull(incoming);
|
||||
assertTrue(incoming.startsWith("CONNECTED"));
|
||||
assertEquals(subProtocol, wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
|
||||
}
|
||||
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
package org.apache.activemq.transport.ws;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
|
@ -28,6 +29,7 @@ import org.apache.activemq.transport.stomp.Stomp;
|
|||
import org.apache.activemq.transport.stomp.StompFrame;
|
||||
import org.apache.activemq.util.Wait;
|
||||
import org.eclipse.jetty.util.ssl.SslContextFactory;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@ -51,10 +53,14 @@ public class StompWSTransportTest extends WSTransportTestSupport {
|
|||
super.setUp();
|
||||
wsStompConnection = new StompWSConnection();
|
||||
|
||||
|
||||
ClientUpgradeRequest request = new ClientUpgradeRequest();
|
||||
request.setSubProtocols("v11.stomp");
|
||||
|
||||
wsClient = new WebSocketClient(new SslContextFactory(true));
|
||||
wsClient.start();
|
||||
|
||||
wsClient.connect(wsStompConnection, wsConnectUri);
|
||||
wsClient.connect(wsStompConnection, wsConnectUri, request);
|
||||
if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) {
|
||||
throw new IOException("Could not connect to STOMP WS endpoint");
|
||||
}
|
||||
|
@ -86,6 +92,7 @@ public class StompWSTransportTest extends WSTransportTestSupport {
|
|||
String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS);
|
||||
assertNotNull(incoming);
|
||||
assertTrue(incoming.startsWith("CONNECTED"));
|
||||
assertEquals("v11.stomp", wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol());
|
||||
|
||||
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {
|
||||
|
||||
|
|
Loading…
Reference in New Issue