mirror of https://github.com/apache/activemq.git
Fixed the MQTTCodec to properly handle frames that come in split up or bunched together.
This commit is contained in:
parent
9743dbddb6
commit
7c04ead460
|
@ -18,144 +18,165 @@ package org.apache.activemq.transport.mqtt;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
import javax.jms.JMSException;
|
||||
import org.apache.activemq.transport.tcp.TcpTransport;
|
||||
import org.fusesource.hawtbuf.DataByteArrayInputStream;
|
||||
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
|
||||
import org.fusesource.mqtt.codec.*;
|
||||
import org.fusesource.mqtt.codec.MQTTFrame;
|
||||
|
||||
public class MQTTCodec {
|
||||
|
||||
TcpTransport transport;
|
||||
private final MQTTFrameSink frameSink;
|
||||
private final DataByteArrayOutputStream currentCommand = new DataByteArrayOutputStream();
|
||||
private byte header;
|
||||
|
||||
DataByteArrayOutputStream currentCommand = new DataByteArrayOutputStream();
|
||||
boolean processedHeader = false;
|
||||
String action;
|
||||
byte header;
|
||||
int contentLength = -1;
|
||||
int previousByte = -1;
|
||||
int payLoadRead = 0;
|
||||
private int contentLength = -1;
|
||||
private int payLoadRead = 0;
|
||||
|
||||
public MQTTCodec(TcpTransport transport) {
|
||||
this.transport = transport;
|
||||
public interface MQTTFrameSink {
|
||||
void onFrame(MQTTFrame mqttFrame);
|
||||
}
|
||||
|
||||
private FrameParser currentParser;
|
||||
|
||||
// Internal parsers implement this and we switch to the next as we go.
|
||||
private interface FrameParser {
|
||||
|
||||
void parse(DataByteArrayInputStream data, int readSize) throws IOException;
|
||||
|
||||
void reset() throws IOException;
|
||||
}
|
||||
|
||||
public MQTTCodec(MQTTFrameSink sink) {
|
||||
this.frameSink = sink;
|
||||
}
|
||||
|
||||
public MQTTCodec(final TcpTransport transport) {
|
||||
this.frameSink = new MQTTFrameSink() {
|
||||
|
||||
@Override
|
||||
public void onFrame(MQTTFrame mqttFrame) {
|
||||
transport.doConsume(mqttFrame);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public void parse(DataByteArrayInputStream input, int readSize) throws Exception {
|
||||
int i = 0;
|
||||
byte b;
|
||||
while (i++ < readSize) {
|
||||
b = input.readByte();
|
||||
// skip repeating nulls
|
||||
if (!processedHeader && b == 0) {
|
||||
previousByte = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!processedHeader) {
|
||||
i += processHeader(b, input);
|
||||
if (contentLength == 0) {
|
||||
processCommand();
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
if (contentLength == -1) {
|
||||
// end of command reached, unmarshal
|
||||
if (b == 0) {
|
||||
processCommand();
|
||||
} else {
|
||||
currentCommand.write(b);
|
||||
}
|
||||
} else {
|
||||
// read desired content length
|
||||
if (payLoadRead == contentLength) {
|
||||
processCommand();
|
||||
i += processHeader(b, input);
|
||||
} else {
|
||||
currentCommand.write(b);
|
||||
payLoadRead++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
previousByte = b;
|
||||
}
|
||||
if (processedHeader && payLoadRead == contentLength) {
|
||||
processCommand();
|
||||
if (currentParser == null) {
|
||||
currentParser = initializeHeaderParser();
|
||||
}
|
||||
|
||||
// Parser stack will run until current incoming data has all been consumed.
|
||||
currentParser.parse(input, readSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* sets the content length
|
||||
*
|
||||
* @return number of bytes read
|
||||
*/
|
||||
private int processHeader(byte header, DataByteArrayInputStream input) {
|
||||
this.header = header;
|
||||
byte digit;
|
||||
int multiplier = 1;
|
||||
int read = 0;
|
||||
int length = 0;
|
||||
do {
|
||||
digit = input.readByte();
|
||||
length += (digit & 0x7F) * multiplier;
|
||||
multiplier <<= 7;
|
||||
read++;
|
||||
} while ((digit & 0x80) != 0);
|
||||
|
||||
contentLength = length;
|
||||
processedHeader = true;
|
||||
return read;
|
||||
}
|
||||
|
||||
|
||||
private void processCommand() throws Exception {
|
||||
private void processCommand() throws IOException {
|
||||
MQTTFrame frame = new MQTTFrame(currentCommand.toBuffer().deepCopy()).header(header);
|
||||
transport.doConsume(frame);
|
||||
processedHeader = false;
|
||||
currentCommand.reset();
|
||||
contentLength = -1;
|
||||
payLoadRead = 0;
|
||||
frameSink.onFrame(frame);
|
||||
}
|
||||
|
||||
public static String commandType(byte header) throws IOException, JMSException {
|
||||
//----- Prepare the current frame parser for use -------------------------//
|
||||
|
||||
byte messageType = (byte) ((header & 0xF0) >>> 4);
|
||||
switch (messageType) {
|
||||
case PINGREQ.TYPE: {
|
||||
return "PINGREQ";
|
||||
private FrameParser initializeHeaderParser() throws IOException {
|
||||
headerParser.reset();
|
||||
return headerParser;
|
||||
}
|
||||
|
||||
private FrameParser initializeVariableLengthParser() throws IOException {
|
||||
variableLengthParser.reset();
|
||||
return variableLengthParser;
|
||||
}
|
||||
|
||||
private FrameParser initializeContentParser() throws IOException {
|
||||
contentParser.reset();
|
||||
return contentParser;
|
||||
}
|
||||
|
||||
//----- Frame parser implementations -------------------------------------//
|
||||
|
||||
private final FrameParser headerParser = new FrameParser() {
|
||||
|
||||
@Override
|
||||
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
|
||||
int i = 0;
|
||||
while (i++ < readSize) {
|
||||
byte b = data.readByte();
|
||||
// skip repeating nulls
|
||||
if (b == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
header = b;
|
||||
|
||||
currentParser = initializeVariableLengthParser();
|
||||
currentParser.parse(data, readSize - 1);
|
||||
return;
|
||||
}
|
||||
case CONNECT.TYPE: {
|
||||
return "CONNECT";
|
||||
}
|
||||
case DISCONNECT.TYPE: {
|
||||
return "DISCONNECT";
|
||||
}
|
||||
case SUBSCRIBE.TYPE: {
|
||||
return "SUBSCRIBE";
|
||||
}
|
||||
case UNSUBSCRIBE.TYPE: {
|
||||
return "UNSUBSCRIBE";
|
||||
}
|
||||
case PUBLISH.TYPE: {
|
||||
return "PUBLISH";
|
||||
}
|
||||
case PUBACK.TYPE: {
|
||||
return "PUBACK";
|
||||
}
|
||||
case PUBREC.TYPE: {
|
||||
return "PUBREC";
|
||||
}
|
||||
case PUBREL.TYPE: {
|
||||
return "PUBREL";
|
||||
}
|
||||
case PUBCOMP.TYPE: {
|
||||
return "PUBCOMP";
|
||||
}
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
||||
}
|
||||
@Override
|
||||
public void reset() throws IOException {
|
||||
header = -1;
|
||||
}
|
||||
};
|
||||
|
||||
private final FrameParser contentParser = new FrameParser() {
|
||||
|
||||
@Override
|
||||
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
|
||||
int i = 0;
|
||||
while (i++ < readSize) {
|
||||
currentCommand.write(data.readByte());
|
||||
payLoadRead++;
|
||||
|
||||
if (payLoadRead == contentLength) {
|
||||
processCommand();
|
||||
currentParser = initializeHeaderParser();
|
||||
currentParser.parse(data, readSize - i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() throws IOException {
|
||||
contentLength = -1;
|
||||
payLoadRead = 0;
|
||||
currentCommand.reset();
|
||||
}
|
||||
};
|
||||
|
||||
private final FrameParser variableLengthParser = new FrameParser() {
|
||||
|
||||
private byte digit;
|
||||
private int multiplier = 1;
|
||||
private int length;
|
||||
|
||||
@Override
|
||||
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
|
||||
int i = 0;
|
||||
while (i++ < readSize) {
|
||||
digit = data.readByte();
|
||||
length += (digit & 0x7F) * multiplier;
|
||||
multiplier <<= 7;
|
||||
if ((digit & 0x80) == 0) {
|
||||
if (length == 0) {
|
||||
processCommand();
|
||||
currentParser = initializeHeaderParser();
|
||||
} else {
|
||||
currentParser = initializeContentParser();
|
||||
contentLength = length;
|
||||
}
|
||||
currentParser.parse(data, readSize - i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() throws IOException {
|
||||
digit = 0;
|
||||
multiplier = 1;
|
||||
length = 0;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -23,13 +23,14 @@ import java.net.UnknownHostException;
|
|||
import java.nio.ByteBuffer;
|
||||
|
||||
import javax.net.SocketFactory;
|
||||
|
||||
import org.apache.activemq.transport.nio.NIOSSLTransport;
|
||||
import org.apache.activemq.wireformat.WireFormat;
|
||||
import org.fusesource.hawtbuf.DataByteArrayInputStream;
|
||||
|
||||
public class MQTTNIOSSLTransport extends NIOSSLTransport {
|
||||
|
||||
MQTTCodec codec;
|
||||
private MQTTCodec codec;
|
||||
|
||||
public MQTTNIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
|
||||
super(wireFormat, socketFactory, remoteLocation, localLocation);
|
||||
|
@ -55,5 +56,4 @@ public class MQTTNIOSSLTransport extends NIOSSLTransport {
|
|||
DataByteArrayInputStream dis = new DataByteArrayInputStream(fill);
|
||||
codec.parse(dis, fill.length);
|
||||
}
|
||||
|
||||
}
|
|
@ -16,6 +16,17 @@
|
|||
*/
|
||||
package org.apache.activemq.transport.mqtt;
|
||||
|
||||
import org.fusesource.mqtt.codec.CONNECT;
|
||||
import org.fusesource.mqtt.codec.DISCONNECT;
|
||||
import org.fusesource.mqtt.codec.PINGREQ;
|
||||
import org.fusesource.mqtt.codec.PUBACK;
|
||||
import org.fusesource.mqtt.codec.PUBCOMP;
|
||||
import org.fusesource.mqtt.codec.PUBLISH;
|
||||
import org.fusesource.mqtt.codec.PUBREC;
|
||||
import org.fusesource.mqtt.codec.PUBREL;
|
||||
import org.fusesource.mqtt.codec.SUBSCRIBE;
|
||||
import org.fusesource.mqtt.codec.UNSUBSCRIBE;
|
||||
|
||||
/**
|
||||
* A set of static methods useful for handling MQTT based client connections.
|
||||
*/
|
||||
|
@ -70,4 +81,41 @@ public class MQTTProtocolSupport {
|
|||
public static String convertActiveMQToMQTT(String destinationName) {
|
||||
return destinationName.replace('.', '/');
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an MQTT header byte, determine the command type that the header
|
||||
* represents.
|
||||
*
|
||||
* @param header
|
||||
* the byte value for the MQTT frame header.
|
||||
*
|
||||
* @return a string value for the given command type.
|
||||
*/
|
||||
public static String commandType(byte header) {
|
||||
byte messageType = (byte) ((header & 0xF0) >>> 4);
|
||||
switch (messageType) {
|
||||
case PINGREQ.TYPE:
|
||||
return "PINGREQ";
|
||||
case CONNECT.TYPE:
|
||||
return "CONNECT";
|
||||
case DISCONNECT.TYPE:
|
||||
return "DISCONNECT";
|
||||
case SUBSCRIBE.TYPE:
|
||||
return "SUBSCRIBE";
|
||||
case UNSUBSCRIBE.TYPE:
|
||||
return "UNSUBSCRIBE";
|
||||
case PUBLISH.TYPE:
|
||||
return "PUBLISH";
|
||||
case PUBACK.TYPE:
|
||||
return "PUBACK";
|
||||
case PUBREC.TYPE:
|
||||
return "PUBREC";
|
||||
case PUBREL.TYPE:
|
||||
return "PUBREL";
|
||||
case PUBCOMP.TYPE:
|
||||
return "PUBCOMP";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,14 @@ import org.apache.activemq.security.TempDestinationAuthorizationEntry;
|
|||
*/
|
||||
public class MQTTAuthTestSupport extends MQTTTestSupport {
|
||||
|
||||
public MQTTAuthTestSupport() {
|
||||
super();
|
||||
}
|
||||
|
||||
public MQTTAuthTestSupport(String connectorScheme, boolean useSSL) {
|
||||
super(connectorScheme, useSSL);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BrokerPlugin configureAuthentication() throws Exception {
|
||||
List<AuthenticationUser> users = new ArrayList<AuthenticationUser>();
|
||||
|
|
|
@ -66,11 +66,15 @@ public class MQTTAuthTests extends MQTTAuthTestSupport {
|
|||
return Arrays.asList(new Object[][] {
|
||||
{"mqtt", false},
|
||||
{"mqtt+ssl", true},
|
||||
{"mqtt+nio", false}
|
||||
// TODO - Fails {"mqtt+nio+ssl", true}
|
||||
{"mqtt+nio", false},
|
||||
{"mqtt+nio+ssl", true}
|
||||
});
|
||||
}
|
||||
|
||||
public MQTTAuthTests(String connectorScheme, boolean useSSL) {
|
||||
super(connectorScheme, useSSL);
|
||||
}
|
||||
|
||||
@Test(timeout = 60 * 1000)
|
||||
public void testAnonymousUserConnect() throws Exception {
|
||||
MQTT mqtt = createMQTTConnection();
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.activemq.transport.mqtt;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.fusesource.hawtbuf.Buffer;
|
||||
import org.fusesource.hawtbuf.DataByteArrayInputStream;
|
||||
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
|
||||
import org.fusesource.hawtbuf.UTF8Buffer;
|
||||
import org.fusesource.mqtt.codec.CONNECT;
|
||||
import org.fusesource.mqtt.codec.MQTTFrame;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Tests the functionality of the MQTTCodec class.
|
||||
*/
|
||||
public class MQTTCodecTest {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(MQTTCodecTest.class);
|
||||
|
||||
private final MQTTWireFormat wireFormat = new MQTTWireFormat();
|
||||
|
||||
private List<MQTTFrame> frames;
|
||||
private MQTTCodec codec;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
frames = new ArrayList<MQTTFrame>();
|
||||
codec = new MQTTCodec(new MQTTCodec.MQTTFrameSink() {
|
||||
|
||||
@Override
|
||||
public void onFrame(MQTTFrame mqttFrame) {
|
||||
frames.add(mqttFrame);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyConnectBytes() throws Exception {
|
||||
|
||||
CONNECT connect = new CONNECT();
|
||||
connect.cleanSession(true);
|
||||
connect.clientId(new UTF8Buffer(""));
|
||||
|
||||
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
|
||||
wireFormat.marshal(connect.encode(), output);
|
||||
Buffer marshalled = output.toBuffer();
|
||||
|
||||
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
|
||||
codec.parse(input, marshalled.length());
|
||||
|
||||
assertTrue(!frames.isEmpty());
|
||||
assertEquals(1, frames.size());
|
||||
|
||||
connect = new CONNECT().decode(frames.get(0));
|
||||
LOG.info("Unmarshalled: {}", connect);
|
||||
assertTrue(connect.cleanSession());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConnectWithCredentialsBackToBack() throws Exception {
|
||||
|
||||
CONNECT connect = new CONNECT();
|
||||
connect.cleanSession(false);
|
||||
connect.clientId(new UTF8Buffer("test"));
|
||||
connect.userName(new UTF8Buffer("user"));
|
||||
connect.password(new UTF8Buffer("pass"));
|
||||
|
||||
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
|
||||
wireFormat.marshal(connect.encode(), output);
|
||||
wireFormat.marshal(connect.encode(), output);
|
||||
Buffer marshalled = output.toBuffer();
|
||||
|
||||
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
|
||||
codec.parse(input, marshalled.length());
|
||||
|
||||
assertTrue(!frames.isEmpty());
|
||||
assertEquals(2, frames.size());
|
||||
|
||||
for (MQTTFrame frame : frames) {
|
||||
connect = new CONNECT().decode(frame);
|
||||
LOG.info("Unmarshalled: {}", connect);
|
||||
assertFalse(connect.cleanSession());
|
||||
assertEquals("user", connect.userName().toString());
|
||||
assertEquals("pass", connect.password().toString());
|
||||
assertEquals("test", connect.clientId().toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProcessInChunks() throws Exception {
|
||||
|
||||
CONNECT connect = new CONNECT();
|
||||
connect.cleanSession(false);
|
||||
connect.clientId(new UTF8Buffer("test"));
|
||||
connect.userName(new UTF8Buffer("user"));
|
||||
connect.password(new UTF8Buffer("pass"));
|
||||
|
||||
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
|
||||
wireFormat.marshal(connect.encode(), output);
|
||||
Buffer marshalled = output.toBuffer();
|
||||
|
||||
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
|
||||
|
||||
int first = marshalled.length() / 2;
|
||||
int second = marshalled.length() - first;
|
||||
|
||||
codec.parse(input, first);
|
||||
codec.parse(input, second);
|
||||
|
||||
assertTrue(!frames.isEmpty());
|
||||
assertEquals(1, frames.size());
|
||||
|
||||
connect = new CONNECT().decode(frames.get(0));
|
||||
LOG.info("Unmarshalled: {}", connect);
|
||||
assertFalse(connect.cleanSession());
|
||||
|
||||
assertEquals("user", connect.userName().toString());
|
||||
assertEquals("pass", connect.password().toString());
|
||||
assertEquals("test", connect.clientId().toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProcessInBytes() throws Exception {
|
||||
|
||||
CONNECT connect = new CONNECT();
|
||||
connect.cleanSession(false);
|
||||
connect.clientId(new UTF8Buffer("test"));
|
||||
connect.userName(new UTF8Buffer("user"));
|
||||
connect.password(new UTF8Buffer("pass"));
|
||||
|
||||
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
|
||||
wireFormat.marshal(connect.encode(), output);
|
||||
Buffer marshalled = output.toBuffer();
|
||||
|
||||
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
|
||||
|
||||
int size = marshalled.length();
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
codec.parse(input, 1);
|
||||
}
|
||||
|
||||
assertTrue(!frames.isEmpty());
|
||||
assertEquals(1, frames.size());
|
||||
|
||||
connect = new CONNECT().decode(frames.get(0));
|
||||
LOG.info("Unmarshalled: {}", connect);
|
||||
assertFalse(connect.cleanSession());
|
||||
|
||||
assertEquals("user", connect.userName().toString());
|
||||
assertEquals("pass", connect.password().toString());
|
||||
assertEquals("test", connect.clientId().toString());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.activemq.transport.mqtt;
|
||||
|
||||
/**
|
||||
* Run the basic tests with the NIO Transport.
|
||||
*/
|
||||
public class MQTTNIOSSLTest extends MQTTTest {
|
||||
|
||||
@Override
|
||||
public String getProtocolScheme() {
|
||||
return "mqtt+nio+ssl";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isUseSSL() {
|
||||
return true;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue