Fix handling of incoming MQTT binary data over WS.  The handler should
use the MQTTCodec to ensure that partial or packed frames are fully
processed
(cherry picked from commit e69367fbc3)
This commit is contained in:
Timothy Bish 2017-05-22 12:27:16 -04:00
parent bc879d762a
commit bf395fcdb3
4 changed files with 94 additions and 20 deletions

View File

@ -21,26 +21,33 @@ import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.activemq.transport.mqtt.MQTTCodec;
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
import org.apache.activemq.util.ByteSequence;
import org.apache.activemq.util.IOExceptionSupport;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.mqtt.codec.DISCONNECT;
import org.fusesource.mqtt.codec.MQTTFrame;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener {
public class MQTTSocket extends AbstractMQTTSocket implements MQTTCodec.MQTTFrameSink, WebSocketListener {
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
private final int ORDERLY_CLOSE_TIMEOUT = 10;
private Session session;
final AtomicBoolean receivedDisconnect = new AtomicBoolean();
private final AtomicBoolean receivedDisconnect = new AtomicBoolean();
private final MQTTCodec codec;
public MQTTSocket(String remoteAddress) {
super(remoteAddress);
this.codec = new MQTTCodec(this, getWireFormat());
}
@Override
@ -78,11 +85,7 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener
protocolLock.lock();
try {
receiveCounter += length;
MQTTFrame frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(bytes, offset, length));
if (frame.messageType() == DISCONNECT.TYPE) {
receivedDisconnect.set(true);
}
getProtocolConverter().onMQTTCommand(frame);
codec.parse(new DataByteArrayInputStream(new Buffer(bytes, offset, length)), length);
} catch (Exception e) {
onException(IOExceptionSupport.create(e));
} finally {
@ -127,4 +130,18 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener
private static int getDefaultSendTimeOut() {
return Integer.getInteger("org.apache.activemq.transport.ws.MQTTSocket.sendTimeout", 30);
}
//----- MQTTCodec Frame Sink event point ---------------------------------//
@Override
public void onFrame(MQTTFrame mqttFrame) {
try {
if (mqttFrame.messageType() == DISCONNECT.TYPE) {
receivedDisconnect.set(true);
}
getProtocolConverter().onMQTTCommand(mqttFrame);
} catch (Exception e) {
onException(IOExceptionSupport.create(e));
}
}
}

View File

@ -58,8 +58,9 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
private final CountDownLatch connectLatch = new CountDownLatch(1);
private final MQTTWireFormat wireFormat = new MQTTWireFormat();
private final BlockingQueue<MQTTFrame> prefetch = new LinkedBlockingDeque<MQTTFrame>();
private final BlockingQueue<MQTTFrame> prefetch = new LinkedBlockingDeque<>();
private boolean writePartialFrames;
private int closeCode = -1;
private String closeMessage;
@ -96,8 +97,7 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
public void connect(CONNECT command) throws Exception {
checkConnected();
ByteSequence payload = wireFormat.marshal(command.encode());
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
sendBytes(wireFormat.marshal(command.encode()));
MQTTFrame incoming = receive(15, TimeUnit.SECONDS);
@ -117,22 +117,19 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
}
DISCONNECT command = new DISCONNECT();
ByteSequence payload = wireFormat.marshal(command.encode());
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
sendBytes(wireFormat.marshal(command.encode()));
}
//---- Send methods ------------------------------------------------------//
public void sendFrame(MQTTFrame frame) throws Exception {
checkConnected();
ByteSequence payload = wireFormat.marshal(frame);
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
sendBytes(wireFormat.marshal(frame));
}
public void keepAlive() throws Exception {
checkConnected();
ByteSequence payload = wireFormat.marshal(new PINGREQ().encode());
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
sendBytes(wireFormat.marshal(new PINGREQ().encode()));
}
//----- Receive methods --------------------------------------------------//
@ -172,6 +169,15 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
return closeMessage;
}
public boolean isWritePartialFrames() {
return writePartialFrames;
}
public MQTTWSConnection setWritePartialFrames(boolean value) {
this.writePartialFrames = value;
return this;
}
//----- WebSocket callback handlers --------------------------------------//
@Override
@ -246,6 +252,17 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
//----- Internal implementation ------------------------------------------//
private void sendBytes(ByteSequence payload) throws IOException {
if (!isWritePartialFrames()) {
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data, payload.offset, payload.length));
} else {
connection.getRemote().sendBytes(ByteBuffer.wrap(
payload.data, payload.offset, payload.length / 2));
connection.getRemote().sendBytes(ByteBuffer.wrap(
payload.data, payload.offset + payload.length / 2, payload.length / 2));
}
}
private void checkConnected() throws IOException {
if (!isConnected()) {
throw new IOException("MQTT WS Connection is closed.");

View File

@ -1,4 +1,4 @@
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@ -16,8 +16,28 @@
*/
package org.apache.activemq.transport.ws;
import java.util.Arrays;
import java.util.Collection;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
@RunWith(Parameterized.class)
public class MQTTWSSTransportTest extends MQTTWSTransportTest {
@Parameters(name="{0}")
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
{"complete-frames", false},
{"partial-frames", false}
});
}
public MQTTWSSTransportTest(String testName, boolean partialFrames) {
super(testName, partialFrames);
}
@Override
protected String getWSConnectorURI() {
return "wss://localhost:61623";

View File

@ -1,4 +1,4 @@
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@ -20,6 +20,8 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@ -36,13 +38,31 @@ import org.fusesource.mqtt.codec.PINGREQ;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
@RunWith(Parameterized.class)
public class MQTTWSTransportTest extends WSTransportTestSupport {
protected WebSocketClient wsClient;
protected MQTTWSConnection wsMQTTConnection;
protected ClientUpgradeRequest request;
protected boolean partialFrames;
@Parameters(name="{0}")
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
{"complete-frames", false},
{"partial-frames", true}
});
}
public MQTTWSTransportTest(String testName, boolean partialFrames) {
this.partialFrames = partialFrames;
}
@Override
@Before
public void setUp() throws Exception {
@ -54,7 +74,7 @@ public class MQTTWSTransportTest extends WSTransportTestSupport {
request = new ClientUpgradeRequest();
request.setSubProtocols("mqttv3.1");
wsMQTTConnection = new MQTTWSConnection();
wsMQTTConnection = new MQTTWSConnection().setWritePartialFrames(partialFrames);
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
@ -79,7 +99,7 @@ public class MQTTWSTransportTest extends WSTransportTestSupport {
for (int i = 0; i < 10; ++i) {
testConnect();
wsMQTTConnection = new MQTTWSConnection();
wsMQTTConnection = new MQTTWSConnection().setWritePartialFrames(partialFrames);
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {