Clean up the MQTT over WebSocket code to allow for handling link
stealing and inactivity monitor.  Ensures that the web socket instances
get cleaned up on errors and avoids leaks that might otherwise arise.
Adds new tests for MQTT over WebSocket.

Adds some missing license headers as well.
This commit is contained in:
Timothy Bish 2015-06-29 18:35:08 -04:00
parent 06202097a2
commit 27edaffded
14 changed files with 860 additions and 75 deletions

View File

@ -1,23 +1,45 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import java.io.IOException;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.command.Command;
import org.apache.activemq.transport.TransportSupport;
import org.apache.activemq.transport.mqtt.MQTTInactivityMonitor;
import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
import org.apache.activemq.transport.mqtt.MQTTTransport;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.fusesource.mqtt.codec.MQTTFrame;
public abstract class AbstractMQTTSocket extends TransportSupport implements MQTTTransport, BrokerServiceAware {
protected MQTTWireFormat wireFormat = new MQTTWireFormat();
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
protected MQTTProtocolConverter protocolConverter = null;
private BrokerService brokerService;
protected MQTTWireFormat wireFormat = new MQTTWireFormat();
protected final MQTTInactivityMonitor mqttInactivityMonitor = new MQTTInactivityMonitor(this, wireFormat);
protected final CountDownLatch socketTransportStarted = new CountDownLatch(1);
protected BrokerService brokerService;
protected volatile int receiveCounter;
protected final String remoteAddress;
public AbstractMQTTSocket(String remoteAddress) {
@ -25,38 +47,51 @@ public abstract class AbstractMQTTSocket extends TransportSupport implements MQT
this.remoteAddress = remoteAddress;
}
protected boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
@Override
public void oneway(Object command) throws IOException {
try {
getProtocolConverter().onActiveMQCommand((Command)command);
} catch (Exception e) {
onException(IOExceptionSupport.create(e));
}
}
protected void doStart() throws Exception {
socketTransportStarted.countDown();
@Override
public void sendToActiveMQ(Command command) {
doConsume(command);
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
}
protected MQTTProtocolConverter getProtocolConverter() {
if( protocolConverter == null ) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
mqttInactivityMonitor.stop();
handleStopped();
}
@Override
public int getReceiveCounter() {
return 0;
protected void doStart() throws Exception {
socketTransportStarted.countDown();
mqttInactivityMonitor.setTransportListener(getTransportListener());
mqttInactivityMonitor.startConnectChecker(wireFormat.getConnectAttemptTimeout());
}
//----- Abstract methods for subclasses to implement ---------------------//
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
public abstract void sendToMQTT(MQTTFrame command) throws IOException;
/**
* Called when the transport is stopping to allow the dervied classes
* a chance to close WebSocket resources.
*
* @throws IOException if an error occurs during the stop.
*/
public abstract void handleStopped() throws IOException;
//----- Accessor methods -------------------------------------------------//
@Override
public MQTTInactivityMonitor getInactivityMonitor() {
return null;
return mqttInactivityMonitor;
}
@Override
@ -69,8 +104,32 @@ public abstract class AbstractMQTTSocket extends TransportSupport implements MQT
return remoteAddress;
}
@Override
public int getReceiveCounter() {
return receiveCounter;
}
@Override
public X509Certificate[] getPeerCertificates() {
return new X509Certificate[0];
}
@Override
public void setBrokerService(BrokerService brokerService) {
this.brokerService = brokerService;
}
//----- Internal support methods -----------------------------------------//
protected MQTTProtocolConverter getProtocolConverter() {
if (protocolConverter == null) {
protocolConverter = new MQTTProtocolConverter(this, brokerService);
}
return protocolConverter;
}
protected boolean transportStartedAtLeastOnce() {
return socketTransportStarted.getCount() == 0;
}
}

View File

@ -18,7 +18,6 @@ package org.apache.activemq.transport.ws.jetty8;
import java.io.IOException;
import org.apache.activemq.command.Command;
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport;
@ -31,23 +30,46 @@ import org.slf4j.LoggerFactory;
public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinaryMessage {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
Connection outbound;
private Connection outbound;
public MQTTSocket(String remoteAddress) {
super(remoteAddress);
}
@Override
public void sendToMQTT(MQTTFrame command) throws IOException {
ByteSequence bytes = wireFormat.marshal(command);
outbound.sendMessage(bytes.getData(), 0, bytes.getLength());
}
@Override
public void handleStopped() throws IOException {
if (outbound != null && outbound.isOpen()) {
outbound.close();
}
}
//----- WebSocket.OnTextMessage callback handlers ------------------------//
@Override
public void onOpen(Connection connection) {
this.outbound = connection;
}
@Override
public void onMessage(byte[] bytes, int offset, int length) {
if (!transportStartedAtLeastOnce()) {
LOG.debug("Waiting for StompSocket to be properly started...");
LOG.debug("Waiting for MQTTSocket to be properly started...");
try {
socketTransportStarted.await();
} catch (InterruptedException e) {
LOG.warn("While waiting for StompSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions...");
LOG.warn("While waiting for MQTTSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions...");
}
}
receiveCounter += length;
try {
MQTTFrame frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(bytes, offset, length));
getProtocolConverter().onMQTTCommand(frame);
@ -56,12 +78,6 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinary
}
}
@Override
public void onOpen(Connection connection) {
this.outbound = connection;
}
@Override
public void onClose(int closeCode, String message) {
try {
@ -70,25 +86,4 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocket.OnBinary
LOG.warn("Failed to close WebSocket", e);
}
}
@Override
public void oneway(Object command) throws IOException {
try {
getProtocolConverter().onActiveMQCommand((Command) command);
} catch (Exception e) {
onException(IOExceptionSupport.create(e));
}
}
@Override
public void sendToActiveMQ(Command command) {
doConsume(command);
}
@Override
public void sendToMQTT(MQTTFrame command) throws IOException {
ByteSequence bytes = wireFormat.marshal(command);
outbound.sendMessage(bytes.getData(), 0, bytes.getLength());
}
}

View File

@ -54,6 +54,7 @@ public class WSServlet extends WebSocketServlet {
@Override
public WebSocket doWebSocketConnect(HttpServletRequest request, String protocol) {
WebSocket socket;
if (protocol != null && protocol.startsWith("mqtt")) {
socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(request));
} else {

View File

@ -19,7 +19,6 @@ package org.apache.activemq.transport.ws.jetty9;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.activemq.command.Command;
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport;
@ -33,40 +32,36 @@ import org.slf4j.LoggerFactory;
public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
Session session;
private Session session;
public MQTTSocket(String remoteAddress) {
super(remoteAddress);
}
@Override
public void oneway(Object command) throws IOException {
try {
getProtocolConverter().onActiveMQCommand((Command) command);
} catch (Exception e) {
onException(IOExceptionSupport.create(e));
}
}
@Override
public void sendToActiveMQ(Command command) {
doConsume(command);
}
@Override
public void sendToMQTT(MQTTFrame command) throws IOException {
ByteSequence bytes = wireFormat.marshal(command);
session.getRemote().sendBytes(ByteBuffer.wrap(bytes.getData(), 0, bytes.getLength()));
}
@Override
public void handleStopped() throws IOException {
if (session != null && session.isOpen()) {
session.close();
}
}
//----- WebSocket.OnTextMessage callback handlers ------------------------//
@Override
public void onWebSocketBinary(byte[] bytes, int offset, int length) {
if (!transportStartedAtLeastOnce()) {
LOG.debug("Waiting for StompSocket to be properly started...");
LOG.debug("Waiting for MQTTSocket to be properly started...");
try {
socketTransportStarted.await();
} catch (InterruptedException e) {
LOG.warn("While waiting for StompSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions...");
LOG.warn("While waiting for MQTTSocket to be properly started, we got interrupted!! Should be okay, but you could see race conditions...");
}
}

View File

@ -1,3 +1,19 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.util;
import static org.junit.Assert.assertEquals;

View File

@ -0,0 +1,253 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import java.io.IOException;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import org.apache.activemq.transport.mqtt.MQTTWireFormat;
import org.apache.activemq.util.ByteSequence;
import org.eclipse.jetty.websocket.WebSocket;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.mqtt.codec.CONNACK;
import org.fusesource.mqtt.codec.CONNECT;
import org.fusesource.mqtt.codec.DISCONNECT;
import org.fusesource.mqtt.codec.MQTTFrame;
import org.fusesource.mqtt.codec.PINGREQ;
import org.fusesource.mqtt.codec.PINGRESP;
import org.fusesource.mqtt.codec.PUBACK;
import org.fusesource.mqtt.codec.PUBCOMP;
import org.fusesource.mqtt.codec.PUBLISH;
import org.fusesource.mqtt.codec.PUBREC;
import org.fusesource.mqtt.codec.PUBREL;
import org.fusesource.mqtt.codec.SUBACK;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Implements a simple WebSocket based MQTT Client that can be used for unit testing.
*/
public class MQTTWSConnection implements WebSocket, WebSocket.OnBinaryMessage {
private static final Logger LOG = LoggerFactory.getLogger(MQTTWSConnection.class);
private static final MQTTFrame PING_RESP_FRAME = new PINGRESP().encode();
private Connection connection;
private final CountDownLatch connectLatch = new CountDownLatch(1);
private final MQTTWireFormat wireFormat = new MQTTWireFormat();
private final BlockingQueue<MQTTFrame> prefetch = new LinkedBlockingDeque<MQTTFrame>();
private int closeCode = -1;
private String closeMessage;
public boolean isConnected() {
return connection != null ? connection.isOpen() : false;
}
public void close() {
if (connection != null) {
connection.close();
}
}
//----- Connection and Disconnection methods -----------------------------//
public void connect() throws Exception {
connect(UUID.randomUUID().toString());
}
public void connect(String clientId) throws Exception {
checkConnected();
CONNECT command = new CONNECT();
command.clientId(new UTF8Buffer(clientId));
command.cleanSession(false);
command.version(3);
command.keepAlive((short) 0);
ByteSequence payload = wireFormat.marshal(command.encode());
connection.sendMessage(payload.data, 0, payload.length);
MQTTFrame incoming = receive(15, TimeUnit.SECONDS);
if (incoming == null || incoming.messageType() != CONNACK.TYPE) {
throw new IOException("Failed to connect to remote service.");
}
}
public void disconnect() throws Exception {
if (!isConnected()) {
return;
}
DISCONNECT command = new DISCONNECT();
ByteSequence payload = wireFormat.marshal(command.encode());
connection.sendMessage(payload.data, 0, payload.length);
}
//---- Send methods ------------------------------------------------------//
public void sendFrame(MQTTFrame frame) throws Exception {
checkConnected();
ByteSequence payload = wireFormat.marshal(frame);
connection.sendMessage(payload.data, 0, payload.length);
}
public void keepAlive() throws Exception {
checkConnected();
ByteSequence payload = wireFormat.marshal(new PINGREQ().encode());
connection.sendMessage(payload.data, 0, payload.length);
}
//----- Receive methods --------------------------------------------------//
public MQTTFrame receive() throws Exception {
checkConnected();
return prefetch.take();
}
public MQTTFrame receive(long timeout, TimeUnit unit) throws Exception {
checkConnected();
return prefetch.poll(timeout, unit);
}
public MQTTFrame receiveNoWait() throws Exception {
checkConnected();
return prefetch.poll();
}
//---- Blocking state change calls ---------------------------------------//
public void awaitConnection() throws InterruptedException {
connectLatch.await();
}
public boolean awaitConnection(long time, TimeUnit unit) throws InterruptedException {
return connectLatch.await(time, unit);
}
//----- Property Accessors -----------------------------------------------//
public int getCloseCode() {
return closeCode;
}
public String getCloseMessage() {
return closeMessage;
}
//----- WebSocket callback handlers --------------------------------------//
@Override
public void onMessage(byte[] data, int offset, int length) {
if (data ==null || length <= 0) {
return;
}
MQTTFrame frame = null;
try {
frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(data, offset, length));
} catch (IOException e) {
LOG.error("Could not decode incoming MQTT Frame: ", e.getMessage());
connection.close();
}
try {
switch (frame.messageType()) {
case PINGREQ.TYPE:
PINGREQ ping = new PINGREQ().decode(frame);
LOG.info("WS-Client read frame: {}", ping);
sendFrame(PING_RESP_FRAME);
break;
case PINGRESP.TYPE:
LOG.info("WS-Client ping response received.");
break;
case CONNACK.TYPE:
CONNACK connAck = new CONNACK().decode(frame);
LOG.info("WS-Client read frame: {}", connAck);
prefetch.put(frame);
break;
case SUBACK.TYPE:
SUBACK subAck = new SUBACK().decode(frame);
LOG.info("WS-Client read frame: {}", subAck);
prefetch.put(frame);
break;
case PUBLISH.TYPE:
PUBLISH publish = new PUBLISH().decode(frame);
LOG.info("WS-Client read frame: {}", publish);
prefetch.put(frame);
break;
case PUBACK.TYPE:
PUBACK pubAck = new PUBACK().decode(frame);
LOG.info("WS-Client read frame: {}", pubAck);
prefetch.put(frame);
break;
case PUBREC.TYPE:
PUBREC pubRec = new PUBREC().decode(frame);
LOG.info("WS-Client read frame: {}", pubRec);
prefetch.put(frame);
break;
case PUBREL.TYPE:
PUBREL pubRel = new PUBREL().decode(frame);
LOG.info("WS-Client read frame: {}", pubRel);
prefetch.put(frame);
break;
case PUBCOMP.TYPE:
PUBCOMP pubComp = new PUBCOMP().decode(frame);
LOG.info("WS-Client read frame: {}", pubComp);
prefetch.put(frame);
break;
default:
LOG.error("Unknown MQTT Frame received.");
connection.close();
}
} catch (Exception e) {
LOG.error("Could not decode incoming MQTT Frame: ", e.getMessage());
connection.close();
}
}
@Override
public void onOpen(Connection connection) {
this.connection = connection;
this.connectLatch.countDown();
}
@Override
public void onClose(int closeCode, String message) {
LOG.trace("MQTT WS Connection closed, code:{} message:{}", closeCode, message);
this.connection = null;
this.closeCode = closeCode;
this.closeMessage = message;
}
//----- Internal implementation ------------------------------------------//
private void checkConnected() throws IOException {
if (!isConnected()) {
throw new IOException("MQTT WS Connection is closed.");
}
}
}

View File

@ -0,0 +1,80 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.Vector;
import java.util.concurrent.TimeUnit;
import org.apache.activemq.util.Wait;
import org.eclipse.jetty.websocket.WebSocketClient;
import org.eclipse.jetty.websocket.WebSocketClientFactory;
import org.junit.Before;
import org.junit.Test;
public class MQTTWSConnectionTimeoutTest extends WSTransportTestSupport {
protected WebSocketClient wsClient;
protected MQTTWSConnection wsMQTTConnection;
protected Vector<Throwable> exceptions = new Vector<Throwable>();
@Override
@Before
public void setUp() throws Exception {
super.setUp();
WebSocketClientFactory clientFactory = new WebSocketClientFactory();
clientFactory.start();
wsClient = clientFactory.newWebSocketClient();
wsClient.setProtocol("mqttv3.1");
wsMQTTConnection = new MQTTWSConnection();
wsClient.open(wsConnectUri, wsMQTTConnection);
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to MQTT WS endpoint");
}
}
protected String getConnectorScheme() {
return "ws";
}
@Test(timeout = 90000)
public void testInactivityMonitor() throws Exception {
assertTrue("one connection", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return 1 == broker.getTransportConnectorByScheme(getConnectorScheme()).connectionCount();
}
}, TimeUnit.SECONDS.toMillis(15), TimeUnit.MILLISECONDS.toMillis(250)));
// and it should be closed due to inactivity
assertTrue("no dangling connections", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return 0 == broker.getTransportConnectorByScheme(getConnectorScheme()).connectionCount();
}
}, TimeUnit.SECONDS.toMillis(60), TimeUnit.MILLISECONDS.toMillis(500)));
assertTrue("no exceptions", exceptions.isEmpty());
}
}

View File

@ -0,0 +1,130 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import org.apache.activemq.util.Wait;
import org.eclipse.jetty.websocket.WebSocketClient;
import org.eclipse.jetty.websocket.WebSocketClientFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
/**
* Test that a WS client can steal links when enabled.
*/
public class MQTTWSLinkStealingTest extends WSTransportTestSupport {
private final String CLIENT_ID = "WS-CLIENT-ID";
protected WebSocketClient wsClient;
protected MQTTWSConnection wsMQTTConnection;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
WebSocketClientFactory clientFactory = new WebSocketClientFactory();
clientFactory.start();
wsClient = clientFactory.newWebSocketClient();
wsClient.setProtocol("mqttv3.1");
wsMQTTConnection = new MQTTWSConnection();
wsClient.open(wsConnectUri, wsMQTTConnection);
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to MQTT WS endpoint");
}
}
@Override
@After
public void tearDown() throws Exception {
if (wsMQTTConnection != null) {
wsMQTTConnection.close();
wsMQTTConnection = null;
wsClient = null;
}
super.tearDown();
}
@Test(timeout = 60000)
public void testConnectAndStealLink() throws Exception {
wsMQTTConnection.connect(CLIENT_ID);
assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 1;
}
}));
WebSocketClientFactory theifFactory = new WebSocketClientFactory();
theifFactory.start();
MQTTWSConnection theif = new MQTTWSConnection();
wsClient.open(wsConnectUri, theif);
if (!theif.awaitConnection(30, TimeUnit.SECONDS)) {
fail("Could not open new WS connection for link stealing client");
}
theif.connect(CLIENT_ID);
assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 1;
}
}));
assertTrue("Original Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return !wsMQTTConnection.isConnected();
}
}));
theif.disconnect();
theif.close();
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 0;
}
}));
}
@Override
protected boolean isAllowLinkStealing() {
return true;
}
}

View File

@ -0,0 +1,216 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.IOException;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.activemq.util.Wait;
import org.eclipse.jetty.websocket.WebSocketClient;
import org.eclipse.jetty.websocket.WebSocketClientFactory;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.mqtt.codec.CONNACK;
import org.fusesource.mqtt.codec.CONNECT;
import org.fusesource.mqtt.codec.MQTTFrame;
import org.fusesource.mqtt.codec.PINGREQ;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
public class MQTTWSTransportTest extends WSTransportTestSupport {
protected WebSocketClient wsClient;
protected MQTTWSConnection wsMQTTConnection;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
WebSocketClientFactory clientFactory = new WebSocketClientFactory();
clientFactory.start();
wsClient = clientFactory.newWebSocketClient();
wsClient.setProtocol("mqttv3.1");
wsMQTTConnection = new MQTTWSConnection();
wsClient.open(wsConnectUri, wsMQTTConnection);
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
throw new IOException("Could not connect to MQTT WS endpoint");
}
}
@Override
@After
public void tearDown() throws Exception {
if (wsMQTTConnection != null) {
wsMQTTConnection.close();
wsMQTTConnection = null;
wsClient = null;
}
super.tearDown();
}
@Test(timeout = 60000)
public void testConnect() throws Exception {
wsMQTTConnection.connect();
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 1;
}
}));
wsMQTTConnection.disconnect();
wsMQTTConnection.close();
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 0;
}
}));
}
@Test(timeout = 60000)
public void testConnectWithNoHeartbeatsClosesConnection() throws Exception {
CONNECT command = new CONNECT();
command.clientId(new UTF8Buffer(UUID.randomUUID().toString()));
command.cleanSession(false);
command.version(3);
command.keepAlive((short) 2);
wsMQTTConnection.sendFrame(command.encode());
MQTTFrame received = wsMQTTConnection.receive(15, TimeUnit.SECONDS);
if (received == null || received.messageType() != CONNACK.TYPE) {
fail("Client did not get expected CONNACK");
}
assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 1;
}
}));
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 0;
}
}));
assertTrue("Client Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return !wsMQTTConnection.isConnected();
}
}));
}
@Test(timeout = 60000)
public void testConnectWithHeartbeatsKeepsConnectionAlive() throws Exception {
final AtomicBoolean done = new AtomicBoolean();
CONNECT command = new CONNECT();
command.clientId(new UTF8Buffer(UUID.randomUUID().toString()));
command.cleanSession(false);
command.version(3);
command.keepAlive((short) 2);
wsMQTTConnection.sendFrame(command.encode());
MQTTFrame received = wsMQTTConnection.receive(15, TimeUnit.SECONDS);
if (received == null || received.messageType() != CONNACK.TYPE) {
fail("Client did not get expected CONNACK");
}
Thread pinger = new Thread(new Runnable() {
@Override
public void run() {
try {
while (!done.get()) {
TimeUnit.SECONDS.sleep(1);
wsMQTTConnection.sendFrame(new PINGREQ().encode());
}
} catch (Exception e) {
}
}
});
pinger.start();
assertTrue("Connection should open", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 1;
}
}));
TimeUnit.SECONDS.sleep(10);
assertTrue("Connection should still open", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 1;
}
}));
wsMQTTConnection.disconnect();
wsMQTTConnection.close();
done.set(true);
assertTrue("Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return getProxyToBroker().getCurrentConnectionsCount() == 0;
}
}));
assertTrue("Client Connection should close", Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisified() throws Exception {
return !wsMQTTConnection.isConnected();
}
}));
}
}

View File

@ -1,3 +1,19 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.ws;
import static org.junit.Assert.assertEquals;
@ -33,5 +49,4 @@ public class SocketTest {
assertEquals("ws://localhost:8080", mqttSocketJetty9.getRemoteAddress());
}
}

View File

@ -70,7 +70,14 @@ public class WSTransportTestSupport {
}
protected String getWSConnectorURI() {
return "ws://127.0.0.1:" + getProxyPort() + "?websocket.maxTextMessageSize=99999&transport.maxIdleTime=1001";
return "ws://127.0.0.1:" + getProxyPort() +
"?allowLinkStealing=" + isAllowLinkStealing() +
"&websocket.maxTextMessageSize=99999&" +
"transport.maxIdleTime=1001";
}
protected boolean isAllowLinkStealing() {
return false;
}
protected void addAdditionalConnectors(BrokerService service) throws Exception {

View File

@ -221,7 +221,7 @@ public class MQTTInactivityMonitor extends TransportFilter {
return protocolConverter;
}
synchronized void startConnectChecker(long connectionTimeout) {
public synchronized void startConnectChecker(long connectionTimeout) {
this.connectionTimeout = connectionTimeout;
if (connectionTimeout > 0 && connectCheckerTask == null) {
connectCheckerTask = new SchedulerTimerTask(connectChecker);

View File

@ -60,7 +60,6 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor
private MQTTInactivityMonitor monitor;
private MQTTWireFormat wireFormat;
private final AtomicBoolean stopped = new AtomicBoolean();
private long connectAttemptTimeout = MQTTWireFormat.DEFAULT_CONNECTION_TIMEOUT;
private boolean trace;
private final Object sendLock = new Object();
@ -216,7 +215,7 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor
* @return the timeout value used to fail a connection if no CONNECT frame read.
*/
public long getConnectAttemptTimeout() {
return connectAttemptTimeout;
return wireFormat.getConnectAttemptTimeout();
}
/**
@ -227,7 +226,7 @@ public class MQTTTransportFilter extends TransportFilter implements MQTTTranspor
* the connection frame received timeout value.
*/
public void setConnectAttemptTimeout(long connectTimeout) {
this.connectAttemptTimeout = connectTimeout;
this.setConnectAttemptTimeout(connectTimeout);
}
public boolean getPublishDollarTopics() {

View File

@ -41,6 +41,7 @@ public class MQTTWireFormat implements WireFormat {
private int version = 1;
private int maxFrameSize = MAX_MESSAGE_LENGTH;
private long connectAttemptTimeout = MQTTWireFormat.DEFAULT_CONNECTION_TIMEOUT;
@Override
public ByteSequence marshal(Object command) throws IOException {
@ -144,4 +145,22 @@ public class MQTTWireFormat implements WireFormat {
public void setMaxFrameSize(int maxFrameSize) {
this.maxFrameSize = Math.min(MAX_MESSAGE_LENGTH, maxFrameSize);
}
/**
* @return the timeout value used to fail a connection if no CONNECT frame read.
*/
public long getConnectAttemptTimeout() {
return connectAttemptTimeout;
}
/**
* Sets the timeout value used to fail a connection if no CONNECT frame is read
* in the given interval.
*
* @param connectTimeout
* the connection frame received timeout value.
*/
public void setConnectAttemptTimeout(long connectTimeout) {
this.connectAttemptTimeout = connectTimeout;
}
}