diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java index af7f86058a..3ded98f1f6 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/StompWSConnection.java @@ -55,6 +55,10 @@ public class StompWSConnection extends WebSocketAdapter implements WebSocketList } } + protected Session getConnection() { + return connection; + } + //---- Send methods ------------------------------------------------------// public synchronized void sendRawFrame(String rawFrame) throws Exception { diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java index 981c2ff80f..cbcdc611ce 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/WSServlet.java @@ -18,6 +18,12 @@ package org.apache.activemq.transport.ws.jetty9; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -42,6 +48,19 @@ public class WSServlet extends WebSocketServlet { private TransportAcceptListener listener; + private final static Map stompProtocols = new ConcurrentHashMap<> (); + private final static Map mqttProtocols = new ConcurrentHashMap<> (); + + static { + stompProtocols.put("v12.stomp", 3); + stompProtocols.put("v11.stomp", 2); + stompProtocols.put("v10.stomp", 1); + stompProtocols.put("stomp", 0); + + mqttProtocols.put("mqttv3.1", 1); + mqttProtocols.put("mqtt", 0); + } + @Override public void init() throws ServletException { super.init(); @@ -70,16 +89,51 @@ public class WSServlet extends WebSocketServlet { } if (isMqtt) { socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); - resp.setAcceptedSubProtocol("mqtt"); + resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols,req.getSubProtocols(), "mqtt")); ((MQTTSocket)socket).setPeerCertificates(req.getCertificates()); } else { socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); ((StompSocket)socket).setCertificates(req.getCertificates()); - resp.setAcceptedSubProtocol("stomp"); + resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols,req.getSubProtocols(), "stomp")); } listener.onAccept((Transport) socket); return socket; } }); } + + private String getAcceptedSubProtocol(final Map protocols, + List subProtocols, String defaultProtocol) { + List matchedProtocols = new ArrayList<>(); + if (subProtocols != null && subProtocols.size() > 0) { + //detect which subprotocols match accepted protocols and add to the list + for (String subProtocol : subProtocols) { + Integer priority = protocols.get(subProtocol); + if(subProtocol != null && priority != null) { + //only insert if both subProtocol and priority are not null + matchedProtocols.add(new SubProtocol(subProtocol, priority)); + } + } + //sort the list by priority + if (matchedProtocols.size() > 0) { + Collections.sort(matchedProtocols, new Comparator() { + @Override + public int compare(SubProtocol s1, SubProtocol s2) { + return s2.priority.compareTo(s1.priority); + } + }); + return matchedProtocols.get(0).protocol; + } + } + return defaultProtocol; + } + + private class SubProtocol { + private String protocol; + private Integer priority; + public SubProtocol(String protocol, Integer priority) { + this.protocol = protocol; + this.priority = priority; + } + } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java index bd5377cf1b..fdbf8673c2 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java @@ -74,6 +74,10 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe } } + protected Session getConnection() { + return connection; + } + //----- Connection and Disconnection methods -----------------------------// public void connect() throws Exception { diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java index 745fdcfd92..a03fdd0dff 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java @@ -41,17 +41,13 @@ public class MQTTWSConnectionTimeoutTest extends WSTransportTestSupport { super.setUp(); wsMQTTConnection = new MQTTWSConnection(); - // WebSocketClientFactory clientFactory = new WebSocketClientFactory(); - //clientFactory.start(); - wsClient = new WebSocketClient(); wsClient.start(); ClientUpgradeRequest request = new ClientUpgradeRequest(); - request.setSubProtocols("mqtt"); + request.setSubProtocols("mqttv3.1"); wsClient.connect(wsMQTTConnection, wsConnectUri, request); - //wsClient.setProtocol("mqttv3.1"); if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { throw new IOException("Could not connect to MQTT WS endpoint"); diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java new file mode 100644 index 0000000000..ea8c262beb --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSubProtocolTest.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class MQTTWSSubProtocolTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected MQTTWSConnection wsMQTTConnection; + protected ClientUpgradeRequest request; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + wsClient = new WebSocketClient(new SslContextFactory(true)); + wsClient.start(); + } + + @Override + @After + public void tearDown() throws Exception { + if (wsMQTTConnection != null) { + wsMQTTConnection.close(); + wsMQTTConnection = null; + wsClient = null; + } + + super.tearDown(); + } + + @Test(timeout = 60000) + public void testConnectv31() throws Exception { + connect("mqttv3.1"); + wsMQTTConnection.connect(); + assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + + @Test(timeout = 60000) + public void testConnectMqtt() throws Exception { + connect("mqtt"); + wsMQTTConnection.connect(); + assertEquals("mqtt", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + + @Test(timeout = 60000) + public void testConnectMultiple() throws Exception { + connect("mqtt,mqttv3.1"); + wsMQTTConnection.connect(); + assertEquals("mqttv3.1", wsMQTTConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + + protected void connect(String subProtocol) throws Exception { + request = new ClientUpgradeRequest(); + if (subProtocol != null) { + request.setSubProtocols(subProtocol); + } + + wsMQTTConnection = new MQTTWSConnection(); + + wsClient.connect(wsMQTTConnection, wsConnectUri, request); + if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to MQTT WS endpoint"); + } + } + +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java new file mode 100644 index 0000000000..492428ee95 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSSubProtocolTest.java @@ -0,0 +1,169 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.apache.activemq.transport.stomp.Stomp; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test STOMP sub protocol detection. + */ +public class StompWSSubProtocolTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected StompWSConnection wsStompConnection; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + wsStompConnection = new StompWSConnection(); + } + + @Override + @After + public void tearDown() throws Exception { + if (wsStompConnection != null) { + wsStompConnection.close(); + wsStompConnection = null; + wsClient = null; + } + + super.tearDown(); + } + + @Test(timeout = 60000) + public void testConnectV12() throws Exception { + connect("v12.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v12.stomp"); + } + + @Test(timeout = 60000) + public void testConnectV11() throws Exception { + connect("v11.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v11.stomp"); + } + + @Test(timeout = 60000) + public void testConnectV10() throws Exception { + connect("v10.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v10.stomp"); + } + + @Test(timeout = 60000) + public void testConnectNone() throws Exception { + + connect(null); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("stomp"); + } + + @Test(timeout = 60000) + public void testConnectMultiple() throws Exception { + + connect("v10.stomp,v11.stomp"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("v11.stomp"); + } + + @Test(timeout = 60000) + public void testConnectInvalid() throws Exception { + connect("invalid"); + + String connectFrame = "STOMP\n" + + "accept-version:1.2\n" + + "host:localhost\n" + + "\n" + Stomp.NULL; + + wsStompConnection.sendRawFrame(connectFrame); + + assertSubProtocol("stomp"); + } + + protected void connect(String subProtocol) throws Exception{ + ClientUpgradeRequest request = new ClientUpgradeRequest(); + if (subProtocol != null) { + request.setSubProtocols(subProtocol); + } + + wsClient = new WebSocketClient(new SslContextFactory(true)); + wsClient.start(); + + wsClient.connect(wsStompConnection, wsConnectUri, request); + if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to STOMP WS endpoint"); + } + } + + protected void assertSubProtocol(String subProtocol) throws Exception { + String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS); + assertNotNull(incoming); + assertTrue(incoming.startsWith("CONNECTED")); + assertEquals(subProtocol, wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); + } + +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java index f84d7842db..dd493469e0 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/StompWSTransportTest.java @@ -16,6 +16,7 @@ */ package org.apache.activemq.transport.ws; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -28,6 +29,7 @@ import org.apache.activemq.transport.stomp.Stomp; import org.apache.activemq.transport.stomp.StompFrame; import org.apache.activemq.util.Wait; import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.junit.After; import org.junit.Before; @@ -51,10 +53,14 @@ public class StompWSTransportTest extends WSTransportTestSupport { super.setUp(); wsStompConnection = new StompWSConnection(); + + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.setSubProtocols("v11.stomp"); + wsClient = new WebSocketClient(new SslContextFactory(true)); wsClient.start(); - wsClient.connect(wsStompConnection, wsConnectUri); + wsClient.connect(wsStompConnection, wsConnectUri, request); if (!wsStompConnection.awaitConnection(30, TimeUnit.SECONDS)) { throw new IOException("Could not connect to STOMP WS endpoint"); } @@ -86,6 +92,7 @@ public class StompWSTransportTest extends WSTransportTestSupport { String incoming = wsStompConnection.receive(30, TimeUnit.SECONDS); assertNotNull(incoming); assertTrue(incoming.startsWith("CONNECTED")); + assertEquals("v11.stomp", wsStompConnection.getConnection().getUpgradeResponse().getAcceptedSubProtocol()); assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {