Improve performance of the codec for large message processing.
This commit is contained in:
Timothy Bish 2014-08-05 12:43:36 -04:00
parent c99e2d8372
commit 052d293143
2 changed files with 185 additions and 41 deletions

View File

@ -19,31 +19,27 @@ package org.apache.activemq.transport.mqtt;
import java.io.IOException;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.mqtt.codec.MQTTFrame;
public class MQTTCodec {
private final MQTTFrameSink frameSink;
private final DataByteArrayOutputStream currentCommand = new DataByteArrayOutputStream();
private byte header;
private int contentLength = -1;
private int payLoadRead = 0;
public interface MQTTFrameSink {
void onFrame(MQTTFrame mqttFrame);
}
private FrameParser currentParser;
// Internal parsers implement this and we switch to the next as we go.
private interface FrameParser {
private final Buffer scratch = new Buffer(8 * 1024);
private Buffer currentBuffer;
void parse(DataByteArrayInputStream data, int readSize) throws IOException;
void reset() throws IOException;
/**
* Sink for newly decoded MQTT Frames.
*/
public interface MQTTFrameSink {
void onFrame(MQTTFrame mqttFrame);
}
public MQTTCodec(MQTTFrameSink sink) {
@ -70,7 +66,16 @@ public class MQTTCodec {
}
private void processCommand() throws IOException {
MQTTFrame frame = new MQTTFrame(currentCommand.toBuffer().deepCopy()).header(header);
Buffer frameContents = null;
if (currentBuffer == scratch) {
frameContents = scratch.deepCopy();
} else {
frameContents = currentBuffer;
currentBuffer = null;
}
MQTTFrame frame = new MQTTFrame(frameContents).header(header);
frameSink.onFrame(frame);
}
@ -93,6 +98,13 @@ public class MQTTCodec {
//----- Frame parser implementations -------------------------------------//
private interface FrameParser {
void parse(DataByteArrayInputStream data, int readSize) throws IOException;
void reset() throws IOException;
}
private final FrameParser headerParser = new FrameParser() {
@Override
@ -108,7 +120,9 @@ public class MQTTCodec {
header = b;
currentParser = initializeVariableLengthParser();
if (readSize > 1) {
currentParser.parse(data, readSize - 1);
}
return;
}
}
@ -116,32 +130,7 @@ public class MQTTCodec {
@Override
public void reset() throws IOException {
header = -1;
}
};
private final FrameParser contentParser = new FrameParser() {
@Override
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
int i = 0;
while (i++ < readSize) {
currentCommand.write(data.readByte());
payLoadRead++;
if (payLoadRead == contentLength) {
processCommand();
currentParser = initializeHeaderParser();
currentParser.parse(data, readSize - i);
return;
}
}
}
@Override
public void reset() throws IOException {
contentLength = -1;
payLoadRead = 0;
currentCommand.reset();
}
};
@ -166,7 +155,11 @@ public class MQTTCodec {
currentParser = initializeContentParser();
contentLength = length;
}
currentParser.parse(data, readSize - i);
readSize = readSize - i;
if (readSize > 0) {
currentParser.parse(data, readSize);
}
return;
}
}
@ -179,4 +172,41 @@ public class MQTTCodec {
length = 0;
}
};
private final FrameParser contentParser = new FrameParser() {
private int payLoadRead = 0;
@Override
public void parse(DataByteArrayInputStream data, int readSize) throws IOException {
if (currentBuffer == null) {
if (contentLength < scratch.length()) {
currentBuffer = scratch;
currentBuffer.length = contentLength;
} else {
currentBuffer = new Buffer(contentLength);
}
}
int length = Math.min(readSize, contentLength - payLoadRead);
payLoadRead += data.read(currentBuffer.data, payLoadRead, length);
if (payLoadRead == contentLength) {
processCommand();
currentParser = initializeHeaderParser();
readSize = readSize - payLoadRead;
if (readSize > 0) {
currentParser.parse(data, readSize);
}
}
}
@Override
public void reset() throws IOException {
contentLength = -1;
payLoadRead = 0;
scratch.reset();
currentBuffer = null;
}
};
}

View File

@ -22,13 +22,18 @@ import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.mqtt.client.QoS;
import org.fusesource.mqtt.client.Topic;
import org.fusesource.mqtt.codec.CONNECT;
import org.fusesource.mqtt.codec.MQTTFrame;
import org.fusesource.mqtt.codec.PUBLISH;
import org.fusesource.mqtt.codec.SUBSCRIBE;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
@ -46,6 +51,9 @@ public class MQTTCodecTest {
private List<MQTTFrame> frames;
private MQTTCodec codec;
private final int MESSAGE_SIZE = 5 * 1024 * 1024;
private final int ITERATIONS = 1000;
@Before
public void setUp() throws Exception {
frames = new ArrayList<MQTTFrame>();
@ -80,6 +88,45 @@ public class MQTTCodecTest {
assertTrue(connect.cleanSession());
}
@Test
public void testConnectThenSubscribe() throws Exception {
CONNECT connect = new CONNECT();
connect.cleanSession(true);
connect.clientId(new UTF8Buffer(""));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(connect.encode(), output);
Buffer marshalled = output.toBuffer();
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
codec.parse(input, marshalled.length());
assertTrue(!frames.isEmpty());
assertEquals(1, frames.size());
connect = new CONNECT().decode(frames.get(0));
LOG.info("Unmarshalled: {}", connect);
assertTrue(connect.cleanSession());
frames.clear();
SUBSCRIBE subscribe = new SUBSCRIBE();
subscribe.topics(new Topic[] {new Topic("TEST", QoS.EXACTLY_ONCE) });
output = new DataByteArrayOutputStream();
wireFormat.marshal(subscribe.encode(), output);
marshalled = output.toBuffer();
input = new DataByteArrayInputStream(marshalled);
codec.parse(input, marshalled.length());
assertTrue(!frames.isEmpty());
assertEquals(1, frames.size());
subscribe = new SUBSCRIBE().decode(frames.get(0));
}
@Test
public void testConnectWithCredentialsBackToBack() throws Exception {
@ -175,4 +222,71 @@ public class MQTTCodecTest {
assertEquals("pass", connect.password().toString());
assertEquals("test", connect.clientId().toString());
}
@Test
public void testMessageDecoding() throws Exception {
byte[] CONTENTS = new byte[MESSAGE_SIZE];
for (int i = 0; i < MESSAGE_SIZE; i++) {
CONTENTS[i] = 'a';
}
PUBLISH publish = new PUBLISH();
publish.dup(false);
publish.messageId((short) 127);
publish.qos(QoS.AT_LEAST_ONCE);
publish.payload(new Buffer(CONTENTS));
publish.topicName(new UTF8Buffer("TOPIC"));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(publish.encode(), output);
Buffer marshalled = output.toBuffer();
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
codec.parse(input, marshalled.length());
assertTrue(!frames.isEmpty());
assertEquals(1, frames.size());
publish = new PUBLISH().decode(frames.get(0));
assertFalse(publish.dup());
assertEquals(MESSAGE_SIZE, publish.payload().length());
}
@Test
public void testMessageDecodingPerformance() throws Exception {
byte[] CONTENTS = new byte[MESSAGE_SIZE];
for (int i = 0; i < MESSAGE_SIZE; i++) {
CONTENTS[i] = 'a';
}
PUBLISH publish = new PUBLISH();
publish.dup(false);
publish.messageId((short) 127);
publish.qos(QoS.AT_LEAST_ONCE);
publish.payload(new Buffer(CONTENTS));
publish.topicName(new UTF8Buffer("TOPIC"));
DataByteArrayOutputStream output = new DataByteArrayOutputStream();
wireFormat.marshal(publish.encode(), output);
Buffer marshalled = output.toBuffer();
long startTime = System.currentTimeMillis();
for (int i = 0; i < ITERATIONS; ++i) {
DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled);
codec.parse(input, marshalled.length());
assertTrue(!frames.isEmpty());
publish = new PUBLISH().decode(frames.get(0));
frames.clear();
}
long duration = System.currentTimeMillis() - startTime;
LOG.info("Total time to process: {}", TimeUnit.MILLISECONDS.toSeconds(duration));
}
}