From a4845253d01b2d6a184d5494da82be938e7a4233 Mon Sep 17 00:00:00 2001 From: "Christopher L. Shannon (cshannon)" Date: Mon, 7 Dec 2015 15:29:13 +0000 Subject: [PATCH] https://issues.apache.org/jira/browse/AMQ-6073 WSServlet for websockets will attempt to detect the subprotocol requested and respond with the appropriate one. Currently the protocols loaded are what stomp.js use for stomp (v11.stomp and v12.stomp). If a protocol can't be found then a default will be returned, either "stomp" or "mqtt", which is the same behavior before this patch. This will make it a bit easier to use stomp over websockets out of the box as stomp.js will work by default. (cherry picked from commit 913f64476b66c452fa03cb4fb8c09d831825bca5) --- .../transport/ws/StompWSConnection.java | 4 + .../transport/ws/jetty9/WSServlet.java | 58 +++++- .../transport/ws/MQTTWSConnection.java | 4 + .../ws/MQTTWSConnectionTimeoutTest.java | 6 +- .../transport/ws/MQTTWSSubProtocolTest.java | 93 ++++++++++ .../transport/ws/StompWSSubProtocolTest.java | 169 ++++++++++++++++++ .../transport/ws/StompWSTransportTest.java | 9 +- 7 files changed, 335 insertions(+), 8 deletions(-) create mode 100644 activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java create mode 100644 activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java index af7f86058a..3ded98f1f6 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java @@ -55,6 +55,10 @@ public class StompWSConnection extends WebSocketAdapter implements WebSocketList } } + protected Session getConnection() { + return connection; + } + //---- Send methods ------------------------------------------------------// public synchronized void sendRawFrame(String rawFrame) throws Exception { diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java index 981c2ff80f..cbcdc611ce 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java @@ -18,6 +18,12 @@ package org.apache.activemq.transport.ws.jetty9; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -42,6 +48,19 @@ public class WSServlet extends WebSocketServlet { private TransportAcceptListener listener; + private final static Map stompProtocols = new ConcurrentHashMap<> (); + private final static Map mqttProtocols = new ConcurrentHashMap<> (); + + static { + stompProtocols.put("v12.stomp", 3); + stompProtocols.put("v11.stomp", 2); + stompProtocols.put("v10.stomp", 1); + stompProtocols.put("stomp", 0); + + mqttProtocols.put("mqttv3.1", 1); + mqttProtocols.put("mqtt", 0); + } + @Override public void init() throws ServletException { super.init(); @@ -70,16 +89,51 @@ public class WSServlet extends WebSocketServlet { } if (isMqtt) { socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); - resp.setAcceptedSubProtocol("mqtt"); + resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols,req.getSubProtocols(), "mqtt")); ((MQTTSocket)socket).setPeerCertificates(req.getCertificates()); } else { socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); ((StompSocket)socket).setCertificates(req.getCertificates()); - resp.setAcceptedSubProtocol("stomp"); + resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols,req.getSubProtocols(), "stomp")); } listener.onAccept((Transport) socket); return socket; } }); } + + private String getAcceptedSubProtocol(final Map protocols, + List subProtocols, String defaultProtocol) { + List matchedProtocols = new ArrayList<>(); + if (subProtocols != null && subProtocols.size() > 0) { + //detect which subprotocols match accepted protocols and add to the list + for (String subProtocol : subProtocols) { + Integer priority = protocols.get(subProtocol); + if(subProtocol != null && priority != null) { + //only insert if both subProtocol and priority are not null + matchedProtocols.add(new SubProtocol(subProtocol, priority)); + } + } + //sort the list by priority + if (matchedProtocols.size() > 0) { + Collections.sort(matchedProtocols, new Comparator() { + @Override + public int compare(SubProtocol s1, SubProtocol s2) { + return s2.priority.compareTo(s1.priority); + } + }); + return matchedProtocols.get(0).protocol; + } + } + return defaultProtocol; + } + + private class SubProtocol { + private String protocol; + private Integer priority; + public SubProtocol(String protocol, Integer priority) { + this.protocol = protocol; + this.priority = priority; + } + } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java index bd5377cf1b..fdbf8673c2 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java @@ -74,6 +74,10 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe } } + protected Session getConnection() { + return connection; + } + //----- Connection and Disconnection methods -----------------------------// public void connect() throws Exception { diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java index 745fdcfd92..a03fdd0dff 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java @@ -41,17 +41,13 @@ public class MQTTWSConnectionTimeoutTest extends WSTransportTestSupport { super.setUp(); wsMQTTConnection = new MQTTWSConnection(); - // WebSocketClientFactory clientFactory = new WebSocketClientFactory(); - //clientFactory.start(); - wsClient = new WebSocketClient(); wsClient.start(); ClientUpgradeRequest request = new ClientUpgradeRequest(); - request.setSubProtocols("mqtt"); + request.setSubProtocols("mqttv3.1"); wsClient.connect(wsMQTTConnection, wsConnectUri, request); - //wsClient.setProtocol("mqttv3.1"); if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { throw new IOException("Could not connect to MQTT WS endpoint"); diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java new file mode 100644 index 0000000000..ea8c262beb --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class MQTTWSSubProtocolTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected MQTTWSConnection wsMQTTConnection; + protected ClientUpgradeRequest request; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + wsClient = new WebSocketClient(new SslContextFactory(true)); + wsClient.start(); + } + + @Override + @After + public void tearDown() throws Exception { + if (wsMQTTConnection != null) { + wsMQTTConnection.close(); + wsMQTTConnection = null; + wsClient = null; + } + + super.tearDown(); + } + + @Test(timeout = 60000) + public void testConnectv31() throws Exception { + connect("mqttv3.1"); + wsMQTTConnection.connect(); + assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + + @Test(timeout = 60000) + public void testConnectMqtt() throws Exception { + connect("mqtt"); + wsMQTTConnection.connect(); + assertEquals("mqtt", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + + @Test(timeout = 60000) + public void testConnectMultiple() throws Exception { + connect("mqtt,mqttv3.1"); + wsMQTTConnection.connect(); + assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + + protected void connect(String subProtocol) throws Exception { + request = new ClientUpgradeRequest(); + if (subProtocol != null) { + request.setSubProtocols(subProtocol); + } + + wsMQTTConnection = new MQTTWSConnection(); + + wsClient.connect(wsMQTTConnection, wsConnectUri, request); + if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to MQTT WS endpoint"); + } + } + +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java new file mode 100644 index 0000000000..492428ee95 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java @@ -0,0 +1,169 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.apache.activemq.transport.stomp.Stomp; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test STOMP sub protocol detection. + */ +public class StompWSSubProtocolTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected StompWSConnection wsStompConnection; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + wsStompConnection = new StompWSConnection(); + } + + @Override + @After + public void tearDown() throws Exception { + if (wsStompConnection != null) { + wsStompConnection.close(); + wsStompConnection = null; + wsClient = null; + } + + super.tearDown(); + } + + @Test(timeout = 60000) + public void testConnectV12() throws Exception { + connect("v12.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v12.stomp"); + } + + @Test(timeout = 60000) + public void testConnectV11() throws Exception { + connect("v11.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v11.stomp"); + } + + @Test(timeout = 60000) + public void testConnectV10() throws Exception { + connect("v10.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v10.stomp"); + } + + @Test(timeout = 60000) + public void testConnectNone() throws Exception { + + connect(null); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("stomp"); + } + + @Test(timeout = 60000) + public void testConnectMultiple() throws Exception { + + connect("v10.stomp,v11.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v11.stomp"); + } + + @Test(timeout = 60000) + public void testConnectInvalid() throws Exception { + connect("invalid"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("stomp"); + } + + protected void connect(String subProtocol) throws Exception{ + ClientUpgradeRequest request = new ClientUpgradeRequest(); + if (subProtocol != null) { + request.setSubProtocols(subProtocol); + } + + wsClient = new WebSocketClient(new SslContextFactory(true)); + wsClient.start(); + + wsClient.connect(wsStompConnection, wsConnectUri, request); + if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to STOMP WS endpoint"); + } + } + + protected void assertSubProtocol(String subProtocol) throws Exception { + String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS); + assertNotNull(incoming); + assertTrue(incoming.startsWith("CONNECTED")); + assertEquals(subProtocol, wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java index f84d7842db..dd493469e0 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java @@ -16,6 +16,7 @@ */ package org.apache.activemq.transport.ws; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -28,6 +29,7 @@ import org.apache.activemq.transport.stomp.Stomp; import org.apache.activemq.transport.stomp.StompFrame; import org.apache.activemq.util.Wait; import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.junit.After; import org.junit.Before; @@ -51,10 +53,14 @@ public class StompWSTransportTest extends WSTransportTestSupport { super.setUp(); wsStompConnection = new StompWSConnection(); + + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.setSubProtocols("v11.stomp"); + wsClient = new WebSocketClient(new SslContextFactory(true)); wsClient.start(); - wsClient.connect(wsStompConnection, wsConnectUri); + wsClient.connect(wsStompConnection, wsConnectUri, request); if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) { throw new IOException("Could not connect to STOMP WS endpoint"); } @@ -86,6 +92,7 @@ public class StompWSTransportTest extends WSTransportTestSupport { String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS); assertNotNull(incoming); assertTrue(incoming.startsWith("CONNECTED")); + assertEquals("v11.stomp", wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {