Fixed the MQTTCodec to properly handle frames that come in split up or
bunched together.
This commit is contained in:
Timothy Bish 2014-08-04 18:58:03 -04:00
parent 9743dbddb6
commit 7c04ead460
7 changed files with 414 additions and 122 deletions

View File

@ -18,144 +18,165 @@ package org.apache.activemq.transport.mqtt;
import java.io.IOException;
import javax.jms.JMSException;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.mqtt.codec.*;
import org.fusesource.mqtt.codec.MQTTFrame;
public class MQTTCodec {
TcpTransport transport;
private final MQTTFrameSink frameSink;
private final DataByteArrayOutputStream currentCommand = new DataByteArrayOutputStream();
private byte header;
DataByteArrayOutputStream currentCommand = new DataByteArrayOutputStream();
boolean processedHeader = false;
String action;
byte header;
int contentLength = -1;
int previousByte = -1;
int payLoadRead = 0;
private int contentLength = -1;
private int payLoadRead = 0;
public MQTTCodec(TcpTransport transport) {
this.transport = transport;
public interface MQTTFrameSink {
void onFrame(MQTTFrame mqttFrame);
}
private FrameParser currentParser;
// Internal parsers implement this and we switch to the next as we go.
private interface FrameParser {
void parse(DataByteArrayInputStream data, int readSize) throws IOException;
void reset() throws IOException;
}
public MQTTCodec(MQTTFrameSink sink) {
this.frameSink = sink;
}
public MQTTCodec(final TcpTransport transport) {
this.frameSink = new MQTTFrameSink() {
@Override
public void onFrame(MQTTFrame mqttFrame) {
transport.doConsume(mqttFrame);
}
};
}
public void parse(DataByteArrayInputStream input, int readSize) throws Exception {
int i = 0;
byte b;
while (i++ < readSize) {
b = input.readByte();
// skip repeating nulls
if (!processedHeader && b == 0) {
previousByte = 0;
continue;
}
if (!processedHeader) {
i += processHeader(b, input);
if (contentLength == 0) {
processCommand();
}
} else {
if (contentLength == -1) {
// end of command reached, unmarshal
if (b == 0) {
processCommand();
} else {
currentCommand.write(b);
}
} else {
// read desired content length
if (payLoadRead == contentLength) {
processCommand();
i += processHeader(b, input);
} else {
currentCommand.write(b);
payLoadRead++;
}
}
}
previousByte = b;
}
if (processedHeader && payLoadRead == contentLength) {
processCommand();
if (currentParser == null) {
currentParser = initializeHeaderParser();
}
// Parser stack will run until current incoming data has all been consumed.
currentParser.parse(input, readSize);
}
/**
* sets the content length
*
* @return number of bytes read
*/
private int processHeader(byte header, DataByteArrayInputStream input) {
this.header = header;
byte digit;
int multiplier = 1;
int read = 0;
int length = 0;
do {
digit = input.readByte();
length += (digit & 0x7F) * multiplier;
multiplier <<= 7;
read++;
} while ((digit & 0x80) != 0);
contentLength = length;
processedHeader = true;
return read;
}
private void processCommand() throws Exception {
private void processCommand() throws IOException {
MQTTFrame frame = new MQTTFrame(currentCommand.toBuffer().deepCopy()).header(header);
transport.doConsume(frame);
processedHeader = false;
currentCommand.reset();
contentLength = -1;
payLoadRead = 0;
frameSink.onFrame(frame);
}
public static String commandType(byte header) throws IOException, JMSException {
//----- Prepare the current frame parser for use -------------------------//
byte messageType = (byte) ((header & 0xF0) >>> 4);
switch (messageType) {
case PINGREQ.TYPE: {
return "PINGREQ";
private FrameParser initializeHeaderParser() throws IOException {
headerParser.reset();
return headerParser;
}
private FrameParser initializeVariableLengthParser() throws IOException {
variableLengthParser.reset();
return variableLengthParser;
}
private FrameParser initializeContentParser() throws IOException {
contentParser.reset();
return contentParser;
}
//----- Frame parser implementations -------------------------------------//
private final FrameParser headerParser = new FrameParser() {
@Override
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
int i = 0;
while (i++ < readSize) {
byte b = data.readByte();
// skip repeating nulls
if (b == 0) {
continue;
}
header = b;
currentParser = initializeVariableLengthParser();
currentParser.parse(data, readSize - 1);
return;
}
case CONNECT.TYPE: {
return "CONNECT";
}
case DISCONNECT.TYPE: {
return "DISCONNECT";
}
case SUBSCRIBE.TYPE: {
return "SUBSCRIBE";
}
case UNSUBSCRIBE.TYPE: {
return "UNSUBSCRIBE";
}
case PUBLISH.TYPE: {
return "PUBLISH";
}
case PUBACK.TYPE: {
return "PUBACK";
}
case PUBREC.TYPE: {
return "PUBREC";
}
case PUBREL.TYPE: {
return "PUBREL";
}
case PUBCOMP.TYPE: {
return "PUBCOMP";
}
default:
return "UNKNOWN";
}
}
@Override
public void reset() throws IOException {
header = -1;
}
};
private final FrameParser contentParser = new FrameParser() {
@Override
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
int i = 0;
while (i++ < readSize) {
currentCommand.write(data.readByte());
payLoadRead++;
if (payLoadRead == contentLength) {
processCommand();
currentParser = initializeHeaderParser();
currentParser.parse(data, readSize - i);
return;
}
}
}
@Override
public void reset() throws IOException {
contentLength = -1;
payLoadRead = 0;
currentCommand.reset();
}
};
private final FrameParser variableLengthParser = new FrameParser() {
private byte digit;
private int multiplier = 1;
private int length;
@Override
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
int i = 0;
while (i++ < readSize) {
digit = data.readByte();
length += (digit & 0x7F) * multiplier;
multiplier <<= 7;
if ((digit & 0x80) == 0) {
if (length == 0) {
processCommand();
currentParser = initializeHeaderParser();
} else {
currentParser = initializeContentParser();
contentLength = length;
}
currentParser.parse(data, readSize - i);
return;
}
}
}
@Override
public void reset() throws IOException {
digit = 0;
multiplier = 1;
length = 0;
}
};
}

View File

@ -23,13 +23,14 @@ import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import javax.net.SocketFactory;
import org.apache.activemq.transport.nio.NIOSSLTransport;
import org.apache.activemq.wireformat.WireFormat;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
public class MQTTNIOSSLTransport extends NIOSSLTransport {
MQTTCodec codec;
private MQTTCodec codec;
public MQTTNIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
super(wireFormat, socketFactory, remoteLocation, localLocation);
@ -55,5 +56,4 @@ public class MQTTNIOSSLTransport extends NIOSSLTransport {
DataByteArrayInputStream dis = new DataByteArrayInputStream(fill);
codec.parse(dis, fill.length);
}
}

View File

@ -16,6 +16,17 @@
*/
package org.apache.activemq.transport.mqtt;
import org.fusesource.mqtt.codec.CONNECT;
import org.fusesource.mqtt.codec.DISCONNECT;
import org.fusesource.mqtt.codec.PINGREQ;
import org.fusesource.mqtt.codec.PUBACK;
import org.fusesource.mqtt.codec.PUBCOMP;
import org.fusesource.mqtt.codec.PUBLISH;
import org.fusesource.mqtt.codec.PUBREC;
import org.fusesource.mqtt.codec.PUBREL;
import org.fusesource.mqtt.codec.SUBSCRIBE;
import org.fusesource.mqtt.codec.UNSUBSCRIBE;
/**
* A set of static methods useful for handling MQTT based client connections.
*/
@ -70,4 +81,41 @@ public class MQTTProtocolSupport {
public static String convertActiveMQToMQTT(String destinationName) {
return destinationName.replace('.', '/');
}
/**
* Given an MQTT header byte, determine the command type that the header
* represents.
*
* @param header
* the byte value for the MQTT frame header.
*
* @return a string value for the given command type.
*/
public static String commandType(byte header) {
byte messageType = (byte) ((header & 0xF0) >>> 4);
switch (messageType) {
case PINGREQ.TYPE:
return "PINGREQ";
case CONNECT.TYPE:
return "CONNECT";
case DISCONNECT.TYPE:
return "DISCONNECT";
case SUBSCRIBE.TYPE:
return "SUBSCRIBE";
case UNSUBSCRIBE.TYPE:
return "UNSUBSCRIBE";
case PUBLISH.TYPE:
return "PUBLISH";
case PUBACK.TYPE:
return "PUBACK";
case PUBREC.TYPE:
return "PUBREC";
case PUBREL.TYPE:
return "PUBREL";
case PUBCOMP.TYPE:
return "PUBCOMP";
default:
return "UNKNOWN";
}
}
}

View File

@ -34,6 +34,14 @@ import org.apache.activemq.security.TempDestinationAuthorizationEntry;
*/
public class MQTTAuthTestSupport extends MQTTTestSupport {
public MQTTAuthTestSupport() {
super();
}
public MQTTAuthTestSupport(String connectorScheme, boolean useSSL) {
super(connectorScheme, useSSL);
}
@Override
protected BrokerPlugin configureAuthentication() throws Exception {
List<AuthenticationUser> users = new ArrayList<AuthenticationUser>();

View File

@ -66,11 +66,15 @@ public class MQTTAuthTests extends MQTTAuthTestSupport {
return Arrays.asList(new Object[][] {
{"mqtt", false},
{"mqtt+ssl", true},
{"mqtt+nio", false}
// TODO - Fails {"mqtt+nio+ssl", true}
{"mqtt+nio", false},
{"mqtt+nio+ssl", true}
});
}
public MQTTAuthTests(String connectorScheme, boolean useSSL) {
super(connectorScheme, useSSL);
}
@Test(timeout = 60 * 1000)
public void testAnonymousUserConnect() throws Exception {
MQTT mqtt = createMQTTConnection();

View File

@ -0,0 +1,178 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.mqtt;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.List;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.mqtt.codec.CONNECT;
import org.fusesource.mqtt.codec.MQTTFrame;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Tests the functionality of the MQTTCodec class.
*/
public class MQTTCodecTest {
private static final Logger LOG = LoggerFactory.getLogger(MQTTCodecTest.class);
private final MQTTWireFormat wireFormat = new MQTTWireFormat();
private List<MQTTFrame> frames;
private MQTTCodec codec;
@Before
public void setUp() throws Exception {
frames = new ArrayList<MQTTFrame>();
codec = new MQTTCodec(new MQTTCodec.MQTTFrameSink() {
@Override
public void onFrame(MQTTFrame mqttFrame) {
frames.add(mqttFrame);
}
});
}
@Test
public void testEmptyConnectBytes() throws Exception {
CONNECT connect = new CONNECT();
connect.cleanSession(true);
connect.clientId(new UTF8Buffer(""));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(connect.encode(), output);
Buffer marshalled = output.toBuffer();
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
codec.parse(input, marshalled.length());
assertTrue(!frames.isEmpty());
assertEquals(1, frames.size());
connect = new CONNECT().decode(frames.get(0));
LOG.info("Unmarshalled: {}", connect);
assertTrue(connect.cleanSession());
}
@Test
public void testConnectWithCredentialsBackToBack() throws Exception {
CONNECT connect = new CONNECT();
connect.cleanSession(false);
connect.clientId(new UTF8Buffer("test"));
connect.userName(new UTF8Buffer("user"));
connect.password(new UTF8Buffer("pass"));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(connect.encode(), output);
wireFormat.marshal(connect.encode(), output);
Buffer marshalled = output.toBuffer();
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
codec.parse(input, marshalled.length());
assertTrue(!frames.isEmpty());
assertEquals(2, frames.size());
for (MQTTFrame frame : frames) {
connect = new CONNECT().decode(frame);
LOG.info("Unmarshalled: {}", connect);
assertFalse(connect.cleanSession());
assertEquals("user", connect.userName().toString());
assertEquals("pass", connect.password().toString());
assertEquals("test", connect.clientId().toString());
}
}
@Test
public void testProcessInChunks() throws Exception {
CONNECT connect = new CONNECT();
connect.cleanSession(false);
connect.clientId(new UTF8Buffer("test"));
connect.userName(new UTF8Buffer("user"));
connect.password(new UTF8Buffer("pass"));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(connect.encode(), output);
Buffer marshalled = output.toBuffer();
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
int first = marshalled.length() / 2;
int second = marshalled.length() - first;
codec.parse(input, first);
codec.parse(input, second);
assertTrue(!frames.isEmpty());
assertEquals(1, frames.size());
connect = new CONNECT().decode(frames.get(0));
LOG.info("Unmarshalled: {}", connect);
assertFalse(connect.cleanSession());
assertEquals("user", connect.userName().toString());
assertEquals("pass", connect.password().toString());
assertEquals("test", connect.clientId().toString());
}
@Test
public void testProcessInBytes() throws Exception {
CONNECT connect = new CONNECT();
connect.cleanSession(false);
connect.clientId(new UTF8Buffer("test"));
connect.userName(new UTF8Buffer("user"));
connect.password(new UTF8Buffer("pass"));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(connect.encode(), output);
Buffer marshalled = output.toBuffer();
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
int size = marshalled.length();
for (int i = 0; i < size; ++i) {
codec.parse(input, 1);
}
assertTrue(!frames.isEmpty());
assertEquals(1, frames.size());
connect = new CONNECT().decode(frames.get(0));
LOG.info("Unmarshalled: {}", connect);
assertFalse(connect.cleanSession());
assertEquals("user", connect.userName().toString());
assertEquals("pass", connect.password().toString());
assertEquals("test", connect.clientId().toString());
}
}

View File

@ -0,0 +1,33 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.mqtt;
/**
* Run the basic tests with the NIO Transport.
*/
public class MQTTNIOSSLTest extends MQTTTest {
@Override
public String getProtocolScheme() {
return "mqtt+nio+ssl";
}
@Override
public boolean isUseSSL() {
return true;
}
}