diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java index 406741c636..5f67d8a3a0 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/AbstractMQTTSocket.java @@ -1,23 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.activemq.transport.ws; +import java.io.IOException; import java.security.cert.X509Certificate; import java.util.concurrent.CountDownLatch; import org.apache.activemq.broker.BrokerService; import org.apache.activemq.broker.BrokerServiceAware; +import org.apache.activemq.command.Command; import org.apache.activemq.transport.TransportSupport; import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor; import org.apache.activemq.transport.mqtt.MQTTProtocolConverter; import org.apache.activemq.transport.mqtt.MQTTTransport; import org.apache.activemq.transport.mqtt.MQTTWireFormat; +import org.apache.activemq.util.IOExceptionSupport; import org.apache.activemq.util.ServiceStopper; +import org.fusesource.mqtt.codec.MQTTFrame; public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware { - protected MQTTWireFormat wireFormat = new MQTTWireFormat(); - protected final CountDownLatch socketTransportStarted = new CountDownLatch(1); protected MQTTProtocolConverter protocolConverter = null; - private BrokerService brokerService; + protected MQTTWireFormat wireFormat = new MQTTWireFormat(); + protected final MQTTInactivityMonitor mqttInactivityMonitor = new MQTTInactivityMonitor(this, wireFormat); + protected final CountDownLatch socketTransportStarted = new CountDownLatch(1); + protected BrokerService brokerService; + protected volatile int receiveCounter; protected final String remoteAddress; public AbstractMQTTSocket(String remoteAddress) { @@ -25,38 +47,51 @@ public abstract class AbstractMQTTSocket extends TransportSupport implements MQT this.remoteAddress = remoteAddress; } - protected boolean transportStartedAtLeastOnce() { - return socketTransportStarted.getCount() == 0; + @Override + public void oneway(Object command) throws IOException { + try { + getProtocolConverter().onActiveMQCommand((Command)command); + } catch (Exception e) { + onException(IOExceptionSupport.create(e)); + } } - protected void doStart() throws Exception { - socketTransportStarted.countDown(); + @Override + public void sendToActiveMQ(Command command) { + doConsume(command); } @Override protected void doStop(ServiceStopper stopper) throws Exception { - } - - protected MQTTProtocolConverter getProtocolConverter() { - if( protocolConverter == null ) { - protocolConverter = new MQTTProtocolConverter(this, brokerService); - } - return protocolConverter; + mqttInactivityMonitor.stop(); + handleStopped(); } @Override - public int getReceiveCounter() { - return 0; + protected void doStart() throws Exception { + socketTransportStarted.countDown(); + mqttInactivityMonitor.setTransportListener(getTransportListener()); + mqttInactivityMonitor.startConnectChecker(wireFormat.getConnectAttemptTimeout()); } + //----- Abstract methods for subclasses to implement ---------------------// + @Override - public X509Certificate[] getPeerCertificates() { - return new X509Certificate[0]; - } + public abstract void sendToMQTT(MQTTFrame command) throws IOException; + + /** + * Called when the transport is stopping to allow the dervied classes + * a chance to close WebSocket resources. + * + * @throws IOException if an error occurs during the stop. + */ + public abstract void handleStopped() throws IOException; + + //----- Accessor methods -------------------------------------------------// @Override public MQTTInactivityMonitor getInactivityMonitor() { - return null; + return mqttInactivityMonitor; } @Override @@ -69,8 +104,32 @@ public abstract class AbstractMQTTSocket extends TransportSupport implements MQT return remoteAddress; } + @Override + public int getReceiveCounter() { + return receiveCounter; + } + + @Override + public X509Certificate[] getPeerCertificates() { + return new X509Certificate[0]; + } + @Override public void setBrokerService(BrokerService brokerService) { this.brokerService = brokerService; } + + //----- Internal support methods -----------------------------------------// + + protected MQTTProtocolConverter getProtocolConverter() { + if (protocolConverter == null) { + protocolConverter = new MQTTProtocolConverter(this, brokerService); + } + + return protocolConverter; + } + + protected boolean transportStartedAtLeastOnce() { + return socketTransportStarted.getCount() == 0; + } } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java index 43f08e41da..7032b1f5ac 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/MQTTSocket.java @@ -18,7 +18,6 @@ package org.apache.activemq.transport.ws.jetty8; import java.io.IOException; -import org.apache.activemq.command.Command; import org.apache.activemq.transport.ws.AbstractMQTTSocket; import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.IOExceptionSupport; @@ -31,23 +30,46 @@ import org.slf4j.LoggerFactory; public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinaryMessage { private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); - Connection outbound; + + private Connection outbound; public MQTTSocket(String remoteAddress) { super(remoteAddress); } + @Override + public void sendToMQTT(MQTTFrame command) throws IOException { + ByteSequence bytes = wireFormat.marshal(command); + outbound.sendMessage(bytes.getData(), 0, bytes.getLength()); + } + + @Override + public void handleStopped() throws IOException { + if (outbound != null && outbound.isOpen()) { + outbound.close(); + } + } + + //----- WebSocket.OnTextMessage callback handlers ------------------------// + + @Override + public void onOpen(Connection connection) { + this.outbound = connection; + } + @Override public void onMessage(byte[] bytes, int offset, int length) { if (!transportStartedAtLeastOnce()) { - LOG.debug("Waiting for StompSocket to be properly started..."); + LOG.debug("Waiting for MQTTSocket to be properly started..."); try { socketTransportStarted.await(); } catch (InterruptedException e) { - LOG.warn("While waiting for StompSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions..."); + LOG.warn("While waiting for MQTTSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions..."); } } + receiveCounter += length; + try { MQTTFrame frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(bytes, offset, length)); getProtocolConverter().onMQTTCommand(frame); @@ -56,12 +78,6 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinary } } - - @Override - public void onOpen(Connection connection) { - this.outbound = connection; - } - @Override public void onClose(int closeCode, String message) { try { @@ -70,25 +86,4 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinary LOG.warn("Failed to close WebSocket", e); } } - - @Override - public void oneway(Object command) throws IOException { - try { - getProtocolConverter().onActiveMQCommand((Command) command); - } catch (Exception e) { - onException(IOExceptionSupport.create(e)); - } - } - - @Override - public void sendToActiveMQ(Command command) { - doConsume(command); - } - - @Override - public void sendToMQTT(MQTTFrame command) throws IOException { - ByteSequence bytes = wireFormat.marshal(command); - outbound.sendMessage(bytes.getData(), 0, bytes.getLength()); - } - } diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java index c5cb706434..91a4c32b25 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty8/WSServlet.java @@ -54,6 +54,7 @@ public class WSServlet extends WebSocketServlet { @Override public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) { WebSocket socket; + if (protocol != null && protocol.startsWith("mqtt")) { socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(request)); } else { diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java index ef7631a057..dc49da743e 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java @@ -19,7 +19,6 @@ package org.apache.activemq.transport.ws.jetty9; import java.io.IOException; import java.nio.ByteBuffer; -import org.apache.activemq.command.Command; import org.apache.activemq.transport.ws.AbstractMQTTSocket; import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.IOExceptionSupport; @@ -33,40 +32,36 @@ import org.slf4j.LoggerFactory; public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener { private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); - Session session; + + private Session session; public MQTTSocket(String remoteAddress) { super(remoteAddress); } - @Override - public void oneway(Object command) throws IOException { - try { - getProtocolConverter().onActiveMQCommand((Command) command); - } catch (Exception e) { - onException(IOExceptionSupport.create(e)); - } - } - - @Override - public void sendToActiveMQ(Command command) { - doConsume(command); - } - @Override public void sendToMQTT(MQTTFrame command) throws IOException { ByteSequence bytes = wireFormat.marshal(command); session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength())); } + @Override + public void handleStopped() throws IOException { + if (session != null && session.isOpen()) { + session.close(); + } + } + + //----- WebSocket.OnTextMessage callback handlers ------------------------// + @Override public void onWebSocketBinary(byte[] bytes, int offset, int length) { if (!transportStartedAtLeastOnce()) { - LOG.debug("Waiting for StompSocket to be properly started..."); + LOG.debug("Waiting for MQTTSocket to be properly started..."); try { socketTransportStarted.await(); } catch (InterruptedException e) { - LOG.warn("While waiting for StompSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions..."); + LOG.warn("While waiting for MQTTSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions..."); } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java index 51fcd048b4..4c0a431e97 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/util/HttpTransportUtilsTest.java @@ -1,3 +1,19 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.activemq.transport.util; import static org.junit.Assert.assertEquals; diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java new file mode 100644 index 0000000000..81788b7fc6 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java @@ -0,0 +1,253 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import java.io.IOException; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.apache.activemq.transport.mqtt.MQTTWireFormat; +import org.apache.activemq.util.ByteSequence; +import org.eclipse.jetty.websocket.WebSocket; +import org.fusesource.hawtbuf.UTF8Buffer; +import org.fusesource.mqtt.codec.CONNACK; +import org.fusesource.mqtt.codec.CONNECT; +import org.fusesource.mqtt.codec.DISCONNECT; +import org.fusesource.mqtt.codec.MQTTFrame; +import org.fusesource.mqtt.codec.PINGREQ; +import org.fusesource.mqtt.codec.PINGRESP; +import org.fusesource.mqtt.codec.PUBACK; +import org.fusesource.mqtt.codec.PUBCOMP; +import org.fusesource.mqtt.codec.PUBLISH; +import org.fusesource.mqtt.codec.PUBREC; +import org.fusesource.mqtt.codec.PUBREL; +import org.fusesource.mqtt.codec.SUBACK; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implements a simple WebSocket based MQTT Client that can be used for unit testing. + */ +public class MQTTWSConnection implements WebSocket, WebSocket.OnBinaryMessage { + + private static final Logger LOG = LoggerFactory.getLogger(MQTTWSConnection.class); + + private static final MQTTFrame PING_RESP_FRAME = new PINGRESP().encode(); + + private Connection connection; + private final CountDownLatch connectLatch = new CountDownLatch(1); + private final MQTTWireFormat wireFormat = new MQTTWireFormat(); + + private final BlockingQueue prefetch = new LinkedBlockingDeque(); + + private int closeCode = -1; + private String closeMessage; + + public boolean isConnected() { + return connection != null ? connection.isOpen() : false; + } + + public void close() { + if (connection != null) { + connection.close(); + } + } + + //----- Connection and Disconnection methods -----------------------------// + + public void connect() throws Exception { + connect(UUID.randomUUID().toString()); + } + + public void connect(String clientId) throws Exception { + checkConnected(); + + CONNECT command = new CONNECT(); + + command.clientId(new UTF8Buffer(clientId)); + command.cleanSession(false); + command.version(3); + command.keepAlive((short) 0); + + ByteSequence payload = wireFormat.marshal(command.encode()); + connection.sendMessage(payload.data, 0, payload.length); + + MQTTFrame incoming = receive(15, TimeUnit.SECONDS); + if (incoming == null || incoming.messageType() != CONNACK.TYPE) { + throw new IOException("Failed to connect to remote service."); + } + } + + public void disconnect() throws Exception { + if (!isConnected()) { + return; + } + + DISCONNECT command = new DISCONNECT(); + ByteSequence payload = wireFormat.marshal(command.encode()); + connection.sendMessage(payload.data, 0, payload.length); + } + + //---- Send methods ------------------------------------------------------// + + public void sendFrame(MQTTFrame frame) throws Exception { + checkConnected(); + ByteSequence payload = wireFormat.marshal(frame); + connection.sendMessage(payload.data, 0, payload.length); + } + + public void keepAlive() throws Exception { + checkConnected(); + ByteSequence payload = wireFormat.marshal(new PINGREQ().encode()); + connection.sendMessage(payload.data, 0, payload.length); + } + + //----- Receive methods --------------------------------------------------// + + public MQTTFrame receive() throws Exception { + checkConnected(); + return prefetch.take(); + } + + public MQTTFrame receive(long timeout, TimeUnit unit) throws Exception { + checkConnected(); + return prefetch.poll(timeout, unit); + } + + public MQTTFrame receiveNoWait() throws Exception { + checkConnected(); + return prefetch.poll(); + } + + //---- Blocking state change calls ---------------------------------------// + + public void awaitConnection() throws InterruptedException { + connectLatch.await(); + } + + public boolean awaitConnection(long time, TimeUnit unit) throws InterruptedException { + return connectLatch.await(time, unit); + } + + //----- Property Accessors -----------------------------------------------// + + public int getCloseCode() { + return closeCode; + } + + public String getCloseMessage() { + return closeMessage; + } + + //----- WebSocket callback handlers --------------------------------------// + + @Override + public void onMessage(byte[] data, int offset, int length) { + if (data ==null || length <= 0) { + return; + } + + MQTTFrame frame = null; + + try { + frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(data, offset, length)); + } catch (IOException e) { + LOG.error("Could not decode incoming MQTT Frame: ", e.getMessage()); + connection.close(); + } + + try { + switch (frame.messageType()) { + case PINGREQ.TYPE: + PINGREQ ping = new PINGREQ().decode(frame); + LOG.info("WS-Client read frame: {}", ping); + sendFrame(PING_RESP_FRAME); + break; + case PINGRESP.TYPE: + LOG.info("WS-Client ping response received."); + break; + case CONNACK.TYPE: + CONNACK connAck = new CONNACK().decode(frame); + LOG.info("WS-Client read frame: {}", connAck); + prefetch.put(frame); + break; + case SUBACK.TYPE: + SUBACK subAck = new SUBACK().decode(frame); + LOG.info("WS-Client read frame: {}", subAck); + prefetch.put(frame); + break; + case PUBLISH.TYPE: + PUBLISH publish = new PUBLISH().decode(frame); + LOG.info("WS-Client read frame: {}", publish); + prefetch.put(frame); + break; + case PUBACK.TYPE: + PUBACK pubAck = new PUBACK().decode(frame); + LOG.info("WS-Client read frame: {}", pubAck); + prefetch.put(frame); + break; + case PUBREC.TYPE: + PUBREC pubRec = new PUBREC().decode(frame); + LOG.info("WS-Client read frame: {}", pubRec); + prefetch.put(frame); + break; + case PUBREL.TYPE: + PUBREL pubRel = new PUBREL().decode(frame); + LOG.info("WS-Client read frame: {}", pubRel); + prefetch.put(frame); + break; + case PUBCOMP.TYPE: + PUBCOMP pubComp = new PUBCOMP().decode(frame); + LOG.info("WS-Client read frame: {}", pubComp); + prefetch.put(frame); + break; + default: + LOG.error("Unknown MQTT Frame received."); + connection.close(); + } + } catch (Exception e) { + LOG.error("Could not decode incoming MQTT Frame: ", e.getMessage()); + connection.close(); + } + } + + @Override + public void onOpen(Connection connection) { + this.connection = connection; + this.connectLatch.countDown(); + } + + @Override + public void onClose(int closeCode, String message) { + LOG.trace("MQTT WS Connection closed, code:{} message:{}", closeCode, message); + + this.connection = null; + this.closeCode = closeCode; + this.closeMessage = message; + } + + //----- Internal implementation ------------------------------------------// + + private void checkConnected() throws IOException { + if (!isConnected()) { + throw new IOException("MQTT WS Connection is closed."); + } + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java new file mode 100644 index 0000000000..d587371861 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnectionTimeoutTest.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Vector; +import java.util.concurrent.TimeUnit; + +import org.apache.activemq.util.Wait; +import org.eclipse.jetty.websocket.WebSocketClient; +import org.eclipse.jetty.websocket.WebSocketClientFactory; +import org.junit.Before; +import org.junit.Test; + +public class MQTTWSConnectionTimeoutTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected MQTTWSConnection wsMQTTConnection; + + protected Vector exceptions = new Vector(); + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + WebSocketClientFactory clientFactory = new WebSocketClientFactory(); + clientFactory.start(); + + wsClient = clientFactory.newWebSocketClient(); + wsClient.setProtocol("mqttv3.1"); + wsMQTTConnection = new MQTTWSConnection(); + + wsClient.open(wsConnectUri, wsMQTTConnection); + if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to MQTT WS endpoint"); + } + } + + protected String getConnectorScheme() { + return "ws"; + } + + @Test(timeout = 90000) + public void testInactivityMonitor() throws Exception { + + assertTrue("one connection", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 1 == broker.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(250))); + + // and it should be closed due to inactivity + assertTrue("no dangling connections", Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisified() throws Exception { + return 0 == broker.getTransportConnectorByScheme(getConnectorScheme()).connectionCount(); + } + }, TimeUnit.SECONDS.toMillis(60), TimeUnit.MILLISECONDS.toMillis(500))); + + assertTrue("no exceptions", exceptions.isEmpty()); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSLinkStealingTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSLinkStealingTest.java new file mode 100644 index 0000000000..2d94eaa895 --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSLinkStealingTest.java @@ -0,0 +1,130 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.apache.activemq.util.Wait; +import org.eclipse.jetty.websocket.WebSocketClient; +import org.eclipse.jetty.websocket.WebSocketClientFactory; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test that a WS client can steal links when enabled. + */ +public class MQTTWSLinkStealingTest extends WSTransportTestSupport { + + private final String CLIENT_ID = "WS-CLIENT-ID"; + + protected WebSocketClient wsClient; + protected MQTTWSConnection wsMQTTConnection; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + WebSocketClientFactory clientFactory = new WebSocketClientFactory(); + clientFactory.start(); + + wsClient = clientFactory.newWebSocketClient(); + wsClient.setProtocol("mqttv3.1"); + + wsMQTTConnection = new MQTTWSConnection(); + + wsClient.open(wsConnectUri, wsMQTTConnection); + if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to MQTT WS endpoint"); + } + } + + @Override + @After + public void tearDown() throws Exception { + if (wsMQTTConnection != null) { + wsMQTTConnection.close(); + wsMQTTConnection = null; + wsClient = null; + } + + super.tearDown(); + } + + @Test(timeout = 60000) + public void testConnectAndStealLink() throws Exception { + + wsMQTTConnection.connect(CLIENT_ID); + + assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 1; + } + })); + + WebSocketClientFactory theifFactory = new WebSocketClientFactory(); + theifFactory.start(); + + MQTTWSConnection theif = new MQTTWSConnection(); + + wsClient.open(wsConnectUri, theif); + if (!theif.awaitConnection(30, TimeUnit.SECONDS)) { + fail("Could not open new WS connection for link stealing client"); + } + + theif.connect(CLIENT_ID); + + assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 1; + } + })); + + assertTrue("Original Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return !wsMQTTConnection.isConnected(); + } + })); + + theif.disconnect(); + theif.close(); + + assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 0; + } + })); + } + + @Override + protected boolean isAllowLinkStealing() { + return true; + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java new file mode 100644 index 0000000000..f304ada8af --- /dev/null +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java @@ -0,0 +1,216 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.ws; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.IOException; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.activemq.util.Wait; +import org.eclipse.jetty.websocket.WebSocketClient; +import org.eclipse.jetty.websocket.WebSocketClientFactory; +import org.fusesource.hawtbuf.UTF8Buffer; +import org.fusesource.mqtt.codec.CONNACK; +import org.fusesource.mqtt.codec.CONNECT; +import org.fusesource.mqtt.codec.MQTTFrame; +import org.fusesource.mqtt.codec.PINGREQ; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class MQTTWSTransportTest extends WSTransportTestSupport { + + protected WebSocketClient wsClient; + protected MQTTWSConnection wsMQTTConnection; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + WebSocketClientFactory clientFactory = new WebSocketClientFactory(); + clientFactory.start(); + + wsClient = clientFactory.newWebSocketClient(); + wsClient.setProtocol("mqttv3.1"); + + wsMQTTConnection = new MQTTWSConnection(); + + wsClient.open(wsConnectUri, wsMQTTConnection); + if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { + throw new IOException("Could not connect to MQTT WS endpoint"); + } + } + + @Override + @After + public void tearDown() throws Exception { + if (wsMQTTConnection != null) { + wsMQTTConnection.close(); + wsMQTTConnection = null; + wsClient = null; + } + + super.tearDown(); + } + + @Test(timeout = 60000) + public void testConnect() throws Exception { + + wsMQTTConnection.connect(); + + assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 1; + } + })); + + wsMQTTConnection.disconnect(); + wsMQTTConnection.close(); + + assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 0; + } + })); + } + + @Test(timeout = 60000) + public void testConnectWithNoHeartbeatsClosesConnection() throws Exception { + + CONNECT command = new CONNECT(); + + command.clientId(new UTF8Buffer(UUID.randomUUID().toString())); + command.cleanSession(false); + command.version(3); + command.keepAlive((short) 2); + + wsMQTTConnection.sendFrame(command.encode()); + + MQTTFrame received = wsMQTTConnection.receive(15, TimeUnit.SECONDS); + if (received == null || received.messageType() != CONNACK.TYPE) { + fail("Client did not get expected CONNACK"); + } + + assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 1; + } + })); + + assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 0; + } + })); + + assertTrue("Client Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return !wsMQTTConnection.isConnected(); + } + })); + } + + @Test(timeout = 60000) + public void testConnectWithHeartbeatsKeepsConnectionAlive() throws Exception { + + final AtomicBoolean done = new AtomicBoolean(); + + CONNECT command = new CONNECT(); + + command.clientId(new UTF8Buffer(UUID.randomUUID().toString())); + command.cleanSession(false); + command.version(3); + command.keepAlive((short) 2); + + wsMQTTConnection.sendFrame(command.encode()); + + MQTTFrame received = wsMQTTConnection.receive(15, TimeUnit.SECONDS); + if (received == null || received.messageType() != CONNACK.TYPE) { + fail("Client did not get expected CONNACK"); + } + + Thread pinger = new Thread(new Runnable() { + + @Override + public void run() { + try { + while (!done.get()) { + TimeUnit.SECONDS.sleep(1); + wsMQTTConnection.sendFrame(new PINGREQ().encode()); + } + } catch (Exception e) { + } + } + }); + + pinger.start(); + + assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 1; + } + })); + + TimeUnit.SECONDS.sleep(10); + + assertTrue("Connection should still open", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 1; + } + })); + + wsMQTTConnection.disconnect(); + wsMQTTConnection.close(); + + done.set(true); + + assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return getProxyToBroker().getCurrentConnectionsCount() == 0; + } + })); + + assertTrue("Client Connection should close", Wait.waitFor(new Wait.Condition() { + + @Override + public boolean isSatisified() throws Exception { + return !wsMQTTConnection.isConnected(); + } + })); + } +} diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java index 871029a7af..c44d672ed8 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/SocketTest.java @@ -1,3 +1,19 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.activemq.transport.ws; import static org.junit.Assert.assertEquals; @@ -33,5 +49,4 @@ public class SocketTest { assertEquals("ws://localhost:8080", mqttSocketJetty9.getRemoteAddress()); } - } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java index 6ab86ff8a1..c7452215f9 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/WSTransportTestSupport.java @@ -70,7 +70,14 @@ public class WSTransportTestSupport { } protected String getWSConnectorURI() { - return "ws://127.0.0.1:" + getProxyPort() + "?websocket.maxTextMessageSize=99999&transport.maxIdleTime=1001"; + return "ws://127.0.0.1:" + getProxyPort() + + "?allowLinkStealing=" + isAllowLinkStealing() + + "&websocket.maxTextMessageSize=99999&" + + "transport.maxIdleTime=1001"; + } + + protected boolean isAllowLinkStealing() { + return false; } protected void addAdditionalConnectors(BrokerService service) throws Exception { diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java index aaad323dee..8c56a24c9e 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTInactivityMonitor.java @@ -221,7 +221,7 @@ public class MQTTInactivityMonitor extends TransportFilter { return protocolConverter; } - synchronized void startConnectChecker(long connectionTimeout) { + public synchronized void startConnectChecker(long connectionTimeout) { this.connectionTimeout = connectionTimeout; if (connectionTimeout > 0 && connectCheckerTask == null) { connectCheckerTask = new SchedulerTimerTask(connectChecker); diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTTransportFilter.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTTransportFilter.java index da84b1ae1d..a347371c32 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTTransportFilter.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTTransportFilter.java @@ -60,7 +60,6 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor private MQTTInactivityMonitor monitor; private MQTTWireFormat wireFormat; private final AtomicBoolean stopped = new AtomicBoolean(); - private long connectAttemptTimeout = MQTTWireFormat.DEFAULT_CONNECTION_TIMEOUT; private boolean trace; private final Object sendLock = new Object(); @@ -216,7 +215,7 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor * @return the timeout value used to fail a connection if no CONNECT frame read. */ public long getConnectAttemptTimeout() { - return connectAttemptTimeout; + return wireFormat.getConnectAttemptTimeout(); } /** @@ -227,7 +226,7 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor * the connection frame received timeout value. */ public void setConnectAttemptTimeout(long connectTimeout) { - this.connectAttemptTimeout = connectTimeout; + this.setConnectAttemptTimeout(connectTimeout); } public boolean getPublishDollarTopics() { diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTWireFormat.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTWireFormat.java index fe1c6aa6a2..2182f67711 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTWireFormat.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTWireFormat.java @@ -41,6 +41,7 @@ public class MQTTWireFormat implements WireFormat { private int version = 1; private int maxFrameSize = MAX_MESSAGE_LENGTH; + private long connectAttemptTimeout = MQTTWireFormat.DEFAULT_CONNECTION_TIMEOUT; @Override public ByteSequence marshal(Object command) throws IOException { @@ -144,4 +145,22 @@ public class MQTTWireFormat implements WireFormat { public void setMaxFrameSize(int maxFrameSize) { this.maxFrameSize = Math.min(MAX_MESSAGE_LENGTH, maxFrameSize); } + + /** + * @return the timeout value used to fail a connection if no CONNECT frame read. + */ + public long getConnectAttemptTimeout() { + return connectAttemptTimeout; + } + + /** + * Sets the timeout value used to fail a connection if no CONNECT frame is read + * in the given interval. + * + * @param connectTimeout + * the connection frame received timeout value. + */ + public void setConnectAttemptTimeout(long connectTimeout) { + this.connectAttemptTimeout = connectTimeout; + } }