From 052d2931434fae255450887e9be288362e17b6fc Mon Sep 17 00:00:00 2001 From: Timothy Bish Date: Tue, 5 Aug 2014 12:43:36 -0400 Subject: [PATCH] https://issues.apache.org/jira/browse/AMQ-5308 Improve performance of the codec for large message processing. --- .../activemq/transport/mqtt/MQTTCodec.java | 112 ++++++++++------- .../transport/mqtt/MQTTCodecTest.java | 114 ++++++++++++++++++ 2 files changed, 185 insertions(+), 41 deletions(-) diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTCodec.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTCodec.java index c892dd101c..6970af73cd 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTCodec.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTCodec.java @@ -19,31 +19,27 @@ package org.apache.activemq.transport.mqtt; import java.io.IOException; import org.apache.activemq.transport.tcp.TcpTransport; +import org.fusesource.hawtbuf.Buffer; import org.fusesource.hawtbuf.DataByteArrayInputStream; -import org.fusesource.hawtbuf.DataByteArrayOutputStream; import org.fusesource.mqtt.codec.MQTTFrame; public class MQTTCodec { private final MQTTFrameSink frameSink; - private final DataByteArrayOutputStream currentCommand = new DataByteArrayOutputStream(); + private byte header; - private int contentLength = -1; - private int payLoadRead = 0; - - public interface MQTTFrameSink { - void onFrame(MQTTFrame mqttFrame); - } private FrameParser currentParser; - // Internal parsers implement this and we switch to the next as we go. - private interface FrameParser { + private final Buffer scratch = new Buffer(8 * 1024); + private Buffer currentBuffer; - void parse(DataByteArrayInputStream data, int readSize) throws IOException; - - void reset() throws IOException; + /** + * Sink for newly decoded MQTT Frames. + */ + public interface MQTTFrameSink { + void onFrame(MQTTFrame mqttFrame); } public MQTTCodec(MQTTFrameSink sink) { @@ -70,7 +66,16 @@ public class MQTTCodec { } private void processCommand() throws IOException { - MQTTFrame frame = new MQTTFrame(currentCommand.toBuffer().deepCopy()).header(header); + + Buffer frameContents = null; + if (currentBuffer == scratch) { + frameContents = scratch.deepCopy(); + } else { + frameContents = currentBuffer; + currentBuffer = null; + } + + MQTTFrame frame = new MQTTFrame(frameContents).header(header); frameSink.onFrame(frame); } @@ -93,6 +98,13 @@ public class MQTTCodec { //----- Frame parser implementations -------------------------------------// + private interface FrameParser { + + void parse(DataByteArrayInputStream data, int readSize) throws IOException; + + void reset() throws IOException; + } + private final FrameParser headerParser = new FrameParser() { @Override @@ -108,7 +120,9 @@ public class MQTTCodec { header = b; currentParser = initializeVariableLengthParser(); - currentParser.parse(data, readSize - 1); + if (readSize > 1) { + currentParser.parse(data, readSize - 1); + } return; } } @@ -116,32 +130,7 @@ public class MQTTCodec { @Override public void reset() throws IOException { header = -1; - } - }; - - private final FrameParser contentParser = new FrameParser() { - - @Override - public void parse(DataByteArrayInputStream data, int readSize) throws IOException { - int i = 0; - while (i++ < readSize) { - currentCommand.write(data.readByte()); - payLoadRead++; - - if (payLoadRead == contentLength) { - processCommand(); - currentParser = initializeHeaderParser(); - currentParser.parse(data, readSize - i); - return; - } - } - } - - @Override - public void reset() throws IOException { contentLength = -1; - payLoadRead = 0; - currentCommand.reset(); } }; @@ -166,7 +155,11 @@ public class MQTTCodec { currentParser = initializeContentParser(); contentLength = length; } - currentParser.parse(data, readSize - i); + + readSize = readSize - i; + if (readSize > 0) { + currentParser.parse(data, readSize); + } return; } } @@ -179,4 +172,41 @@ public class MQTTCodec { length = 0; } }; + + private final FrameParser contentParser = new FrameParser() { + + private int payLoadRead = 0; + + @Override + public void parse(DataByteArrayInputStream data, int readSize) throws IOException { + if (currentBuffer == null) { + if (contentLength < scratch.length()) { + currentBuffer = scratch; + currentBuffer.length = contentLength; + } else { + currentBuffer = new Buffer(contentLength); + } + } + + int length = Math.min(readSize, contentLength - payLoadRead); + payLoadRead += data.read(currentBuffer.data, payLoadRead, length); + + if (payLoadRead == contentLength) { + processCommand(); + currentParser = initializeHeaderParser(); + readSize = readSize - payLoadRead; + if (readSize > 0) { + currentParser.parse(data, readSize); + } + } + } + + @Override + public void reset() throws IOException { + contentLength = -1; + payLoadRead = 0; + scratch.reset(); + currentBuffer = null; + } + }; } diff --git a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTCodecTest.java b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTCodecTest.java index 31af1ab8e7..a1b087b05f 100644 --- a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTCodecTest.java +++ b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTCodecTest.java @@ -22,13 +22,18 @@ import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.TimeUnit; import org.fusesource.hawtbuf.Buffer; import org.fusesource.hawtbuf.DataByteArrayInputStream; import org.fusesource.hawtbuf.DataByteArrayOutputStream; import org.fusesource.hawtbuf.UTF8Buffer; +import org.fusesource.mqtt.client.QoS; +import org.fusesource.mqtt.client.Topic; import org.fusesource.mqtt.codec.CONNECT; import org.fusesource.mqtt.codec.MQTTFrame; +import org.fusesource.mqtt.codec.PUBLISH; +import org.fusesource.mqtt.codec.SUBSCRIBE; import org.junit.Before; import org.junit.Test; import org.slf4j.Logger; @@ -46,6 +51,9 @@ public class MQTTCodecTest { private List frames; private MQTTCodec codec; + private final int MESSAGE_SIZE = 5 * 1024 * 1024; + private final int ITERATIONS = 1000; + @Before public void setUp() throws Exception { frames = new ArrayList(); @@ -80,6 +88,45 @@ public class MQTTCodecTest { assertTrue(connect.cleanSession()); } + @Test + public void testConnectThenSubscribe() throws Exception { + + CONNECT connect = new CONNECT(); + connect.cleanSession(true); + connect.clientId(new UTF8Buffer("")); + + DataByteArrayOutputStream output = new DataByteArrayOutputStream(); + wireFormat.marshal(connect.encode(), output); + Buffer marshalled = output.toBuffer(); + + DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled); + codec.parse(input, marshalled.length()); + + assertTrue(!frames.isEmpty()); + assertEquals(1, frames.size()); + + connect = new CONNECT().decode(frames.get(0)); + LOG.info("Unmarshalled: {}", connect); + assertTrue(connect.cleanSession()); + + frames.clear(); + + SUBSCRIBE subscribe = new SUBSCRIBE(); + subscribe.topics(new Topic[] {new Topic("TEST", QoS.EXACTLY_ONCE) }); + + output = new DataByteArrayOutputStream(); + wireFormat.marshal(subscribe.encode(), output); + marshalled = output.toBuffer(); + + input = new DataByteArrayInputStream(marshalled); + codec.parse(input, marshalled.length()); + + assertTrue(!frames.isEmpty()); + assertEquals(1, frames.size()); + + subscribe = new SUBSCRIBE().decode(frames.get(0)); + } + @Test public void testConnectWithCredentialsBackToBack() throws Exception { @@ -175,4 +222,71 @@ public class MQTTCodecTest { assertEquals("pass", connect.password().toString()); assertEquals("test", connect.clientId().toString()); } + + @Test + public void testMessageDecoding() throws Exception { + + byte[] CONTENTS = new byte[MESSAGE_SIZE]; + for (int i = 0; i < MESSAGE_SIZE; i++) { + CONTENTS[i] = 'a'; + } + + PUBLISH publish = new PUBLISH(); + + publish.dup(false); + publish.messageId((short) 127); + publish.qos(QoS.AT_LEAST_ONCE); + publish.payload(new Buffer(CONTENTS)); + publish.topicName(new UTF8Buffer("TOPIC")); + + DataByteArrayOutputStream output = new DataByteArrayOutputStream(); + wireFormat.marshal(publish.encode(), output); + Buffer marshalled = output.toBuffer(); + + DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled); + codec.parse(input, marshalled.length()); + + assertTrue(!frames.isEmpty()); + assertEquals(1, frames.size()); + + publish = new PUBLISH().decode(frames.get(0)); + assertFalse(publish.dup()); + assertEquals(MESSAGE_SIZE, publish.payload().length()); + } + + @Test + public void testMessageDecodingPerformance() throws Exception { + + byte[] CONTENTS = new byte[MESSAGE_SIZE]; + for (int i = 0; i < MESSAGE_SIZE; i++) { + CONTENTS[i] = 'a'; + } + + PUBLISH publish = new PUBLISH(); + + publish.dup(false); + publish.messageId((short) 127); + publish.qos(QoS.AT_LEAST_ONCE); + publish.payload(new Buffer(CONTENTS)); + publish.topicName(new UTF8Buffer("TOPIC")); + + DataByteArrayOutputStream output = new DataByteArrayOutputStream(); + wireFormat.marshal(publish.encode(), output); + Buffer marshalled = output.toBuffer(); + + long startTime = System.currentTimeMillis(); + + for (int i = 0; i < ITERATIONS; ++i) { + DataByteArrayInputStream input = new DataByteArrayInputStream(marshalled); + codec.parse(input, marshalled.length()); + + assertTrue(!frames.isEmpty()); + publish = new PUBLISH().decode(frames.get(0)); + frames.clear(); + } + + long duration = System.currentTimeMillis() - startTime; + + LOG.info("Total time to process: {}", TimeUnit.MILLISECONDS.toSeconds(duration)); + } }