mirror of https://github.com/apache/activemq.git
Fix handling of incoming MQTT binary data over WS. The handler should
use the MQTTCodec to ensure that partial or packed frames are fully
processed
(cherry picked from commit e69367fbc3
)
This commit is contained in:
parent
bc879d762a
commit
bf395fcdb3
|
@ -21,26 +21,33 @@ import java.nio.ByteBuffer;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import org.apache.activemq.transport.mqtt.MQTTCodec;
|
||||
import org.apache.activemq.transport.ws.AbstractMQTTSocket;
|
||||
import org.apache.activemq.util.ByteSequence;
|
||||
import org.apache.activemq.util.IOExceptionSupport;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.api.WebSocketListener;
|
||||
import org.fusesource.hawtbuf.Buffer;
|
||||
import org.fusesource.hawtbuf.DataByteArrayInputStream;
|
||||
import org.fusesource.mqtt.codec.DISCONNECT;
|
||||
import org.fusesource.mqtt.codec.MQTTFrame;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener {
|
||||
public class MQTTSocket extends AbstractMQTTSocket implements MQTTCodec.MQTTFrameSink, WebSocketListener {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class);
|
||||
|
||||
private final int ORDERLY_CLOSE_TIMEOUT = 10;
|
||||
private Session session;
|
||||
final AtomicBoolean receivedDisconnect = new AtomicBoolean();
|
||||
private final AtomicBoolean receivedDisconnect = new AtomicBoolean();
|
||||
|
||||
private final MQTTCodec codec;
|
||||
|
||||
public MQTTSocket(String remoteAddress) {
|
||||
super(remoteAddress);
|
||||
|
||||
this.codec = new MQTTCodec(this, getWireFormat());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -78,11 +85,7 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener
|
|||
protocolLock.lock();
|
||||
try {
|
||||
receiveCounter += length;
|
||||
MQTTFrame frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(bytes, offset, length));
|
||||
if (frame.messageType() == DISCONNECT.TYPE) {
|
||||
receivedDisconnect.set(true);
|
||||
}
|
||||
getProtocolConverter().onMQTTCommand(frame);
|
||||
codec.parse(new DataByteArrayInputStream(new Buffer(bytes, offset, length)), length);
|
||||
} catch (Exception e) {
|
||||
onException(IOExceptionSupport.create(e));
|
||||
} finally {
|
||||
|
@ -127,4 +130,18 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener
|
|||
private static int getDefaultSendTimeOut() {
|
||||
return Integer.getInteger("org.apache.activemq.transport.ws.MQTTSocket.sendTimeout", 30);
|
||||
}
|
||||
|
||||
//----- MQTTCodec Frame Sink event point ---------------------------------//
|
||||
|
||||
@Override
|
||||
public void onFrame(MQTTFrame mqttFrame) {
|
||||
try {
|
||||
if (mqttFrame.messageType() == DISCONNECT.TYPE) {
|
||||
receivedDisconnect.set(true);
|
||||
}
|
||||
getProtocolConverter().onMQTTCommand(mqttFrame);
|
||||
} catch (Exception e) {
|
||||
onException(IOExceptionSupport.create(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,8 +58,9 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
|
|||
private final CountDownLatch connectLatch = new CountDownLatch(1);
|
||||
private final MQTTWireFormat wireFormat = new MQTTWireFormat();
|
||||
|
||||
private final BlockingQueue<MQTTFrame> prefetch = new LinkedBlockingDeque<MQTTFrame>();
|
||||
private final BlockingQueue<MQTTFrame> prefetch = new LinkedBlockingDeque<>();
|
||||
|
||||
private boolean writePartialFrames;
|
||||
private int closeCode = -1;
|
||||
private String closeMessage;
|
||||
|
||||
|
@ -96,8 +97,7 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
|
|||
public void connect(CONNECT command) throws Exception {
|
||||
checkConnected();
|
||||
|
||||
ByteSequence payload = wireFormat.marshal(command.encode());
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
|
||||
sendBytes(wireFormat.marshal(command.encode()));
|
||||
|
||||
MQTTFrame incoming = receive(15, TimeUnit.SECONDS);
|
||||
|
||||
|
@ -117,22 +117,19 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
|
|||
}
|
||||
|
||||
DISCONNECT command = new DISCONNECT();
|
||||
ByteSequence payload = wireFormat.marshal(command.encode());
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
|
||||
sendBytes(wireFormat.marshal(command.encode()));
|
||||
}
|
||||
|
||||
//---- Send methods ------------------------------------------------------//
|
||||
|
||||
public void sendFrame(MQTTFrame frame) throws Exception {
|
||||
checkConnected();
|
||||
ByteSequence payload = wireFormat.marshal(frame);
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
|
||||
sendBytes(wireFormat.marshal(frame));
|
||||
}
|
||||
|
||||
public void keepAlive() throws Exception {
|
||||
checkConnected();
|
||||
ByteSequence payload = wireFormat.marshal(new PINGREQ().encode());
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data));
|
||||
sendBytes(wireFormat.marshal(new PINGREQ().encode()));
|
||||
}
|
||||
|
||||
//----- Receive methods --------------------------------------------------//
|
||||
|
@ -172,6 +169,15 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
|
|||
return closeMessage;
|
||||
}
|
||||
|
||||
public boolean isWritePartialFrames() {
|
||||
return writePartialFrames;
|
||||
}
|
||||
|
||||
public MQTTWSConnection setWritePartialFrames(boolean value) {
|
||||
this.writePartialFrames = value;
|
||||
return this;
|
||||
}
|
||||
|
||||
//----- WebSocket callback handlers --------------------------------------//
|
||||
|
||||
@Override
|
||||
|
@ -246,6 +252,17 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe
|
|||
|
||||
//----- Internal implementation ------------------------------------------//
|
||||
|
||||
private void sendBytes(ByteSequence payload) throws IOException {
|
||||
if (!isWritePartialFrames()) {
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data, payload.offset, payload.length));
|
||||
} else {
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(
|
||||
payload.data, payload.offset, payload.length / 2));
|
||||
connection.getRemote().sendBytes(ByteBuffer.wrap(
|
||||
payload.data, payload.offset + payload.length / 2, payload.length / 2));
|
||||
}
|
||||
}
|
||||
|
||||
private void checkConnected() throws IOException {
|
||||
if (!isConnected()) {
|
||||
throw new IOException("MQTT WS Connection is closed.");
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/**
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
|
@ -16,8 +16,28 @@
|
|||
*/
|
||||
package org.apache.activemq.transport.ws;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.runners.Parameterized.Parameters;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class MQTTWSSTransportTest extends MQTTWSTransportTest {
|
||||
|
||||
@Parameters(name="{0}")
|
||||
public static Collection<Object[]> data() {
|
||||
return Arrays.asList(new Object[][] {
|
||||
{"complete-frames", false},
|
||||
{"partial-frames", false}
|
||||
});
|
||||
}
|
||||
|
||||
public MQTTWSSTransportTest(String testName, boolean partialFrames) {
|
||||
super(testName, partialFrames);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getWSConnectorURI() {
|
||||
return "wss://localhost:61623";
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/**
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
|
@ -20,6 +20,8 @@ import static org.junit.Assert.assertTrue;
|
|||
import static org.junit.Assert.fail;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
@ -36,13 +38,31 @@ import org.fusesource.mqtt.codec.PINGREQ;
|
|||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.runners.Parameterized.Parameters;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class MQTTWSTransportTest extends WSTransportTestSupport {
|
||||
|
||||
protected WebSocketClient wsClient;
|
||||
protected MQTTWSConnection wsMQTTConnection;
|
||||
protected ClientUpgradeRequest request;
|
||||
|
||||
protected boolean partialFrames;
|
||||
|
||||
@Parameters(name="{0}")
|
||||
public static Collection<Object[]> data() {
|
||||
return Arrays.asList(new Object[][] {
|
||||
{"complete-frames", false},
|
||||
{"partial-frames", true}
|
||||
});
|
||||
}
|
||||
|
||||
public MQTTWSTransportTest(String testName, boolean partialFrames) {
|
||||
this.partialFrames = partialFrames;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
|
@ -54,7 +74,7 @@ public class MQTTWSTransportTest extends WSTransportTestSupport {
|
|||
request = new ClientUpgradeRequest();
|
||||
request.setSubProtocols("mqttv3.1");
|
||||
|
||||
wsMQTTConnection = new MQTTWSConnection();
|
||||
wsMQTTConnection = new MQTTWSConnection().setWritePartialFrames(partialFrames);
|
||||
|
||||
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
|
||||
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
|
||||
|
@ -79,7 +99,7 @@ public class MQTTWSTransportTest extends WSTransportTestSupport {
|
|||
for (int i = 0; i < 10; ++i) {
|
||||
testConnect();
|
||||
|
||||
wsMQTTConnection = new MQTTWSConnection();
|
||||
wsMQTTConnection = new MQTTWSConnection().setWritePartialFrames(partialFrames);
|
||||
|
||||
wsClient.connect(wsMQTTConnection, wsConnectUri, request);
|
||||
if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {
|
||||
|
|
Loading…
Reference in New Issue