diff --git a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java index 2b4be15cd2..3062d924e9 100644 --- a/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java +++ b/activemq-http/src/main/java/org/apache/activemq/transport/ws/jetty9/MQTTSocket.java @@ -21,26 +21,33 @@ import java.nio.ByteBuffer; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.activemq.transport.mqtt.MQTTCodec; import org.apache.activemq.transport.ws.AbstractMQTTSocket; import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.IOExceptionSupport; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.WebSocketListener; +import org.fusesource.hawtbuf.Buffer; +import org.fusesource.hawtbuf.DataByteArrayInputStream; import org.fusesource.mqtt.codec.DISCONNECT; import org.fusesource.mqtt.codec.MQTTFrame; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener { +public class MQTTSocket extends AbstractMQTTSocket implements MQTTCodec.MQTTFrameSink, WebSocketListener { private static final Logger LOG = LoggerFactory.getLogger(MQTTSocket.class); private final int ORDERLY_CLOSE_TIMEOUT = 10; private Session session; - final AtomicBoolean receivedDisconnect = new AtomicBoolean(); + private final AtomicBoolean receivedDisconnect = new AtomicBoolean(); + + private final MQTTCodec codec; public MQTTSocket(String remoteAddress) { super(remoteAddress); + + this.codec = new MQTTCodec(this, getWireFormat()); } @Override @@ -78,11 +85,7 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener protocolLock.lock(); try { receiveCounter += length; - MQTTFrame frame = (MQTTFrame)wireFormat.unmarshal(new ByteSequence(bytes, offset, length)); - if (frame.messageType() == DISCONNECT.TYPE) { - receivedDisconnect.set(true); - } - getProtocolConverter().onMQTTCommand(frame); + codec.parse(new DataByteArrayInputStream(new Buffer(bytes, offset, length)), length); } catch (Exception e) { onException(IOExceptionSupport.create(e)); } finally { @@ -127,4 +130,18 @@ public class MQTTSocket extends AbstractMQTTSocket implements WebSocketListener private static int getDefaultSendTimeOut() { return Integer.getInteger("org.apache.activemq.transport.ws.MQTTSocket.sendTimeout", 30); } + + //----- MQTTCodec Frame Sink event point ---------------------------------// + + @Override + public void onFrame(MQTTFrame mqttFrame) { + try { + if (mqttFrame.messageType() == DISCONNECT.TYPE) { + receivedDisconnect.set(true); + } + getProtocolConverter().onMQTTCommand(mqttFrame); + } catch (Exception e) { + onException(IOExceptionSupport.create(e)); + } + } } diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java index f14da2344c..fd42dc64c9 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSConnection.java @@ -58,8 +58,9 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe private final CountDownLatch connectLatch = new CountDownLatch(1); private final MQTTWireFormat wireFormat = new MQTTWireFormat(); - private final BlockingQueue prefetch = new LinkedBlockingDeque(); + private final BlockingQueue prefetch = new LinkedBlockingDeque<>(); + private boolean writePartialFrames; private int closeCode = -1; private String closeMessage; @@ -96,8 +97,7 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe public void connect(CONNECT command) throws Exception { checkConnected(); - ByteSequence payload = wireFormat.marshal(command.encode()); - connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data)); + sendBytes(wireFormat.marshal(command.encode())); MQTTFrame incoming = receive(15, TimeUnit.SECONDS); @@ -117,22 +117,19 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe } DISCONNECT command = new DISCONNECT(); - ByteSequence payload = wireFormat.marshal(command.encode()); - connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data)); + sendBytes(wireFormat.marshal(command.encode())); } //---- Send methods ------------------------------------------------------// public void sendFrame(MQTTFrame frame) throws Exception { checkConnected(); - ByteSequence payload = wireFormat.marshal(frame); - connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data)); + sendBytes(wireFormat.marshal(frame)); } public void keepAlive() throws Exception { checkConnected(); - ByteSequence payload = wireFormat.marshal(new PINGREQ().encode()); - connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data)); + sendBytes(wireFormat.marshal(new PINGREQ().encode())); } //----- Receive methods --------------------------------------------------// @@ -172,6 +169,15 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe return closeMessage; } + public boolean isWritePartialFrames() { + return writePartialFrames; + } + + public MQTTWSConnection setWritePartialFrames(boolean value) { + this.writePartialFrames = value; + return this; + } + //----- WebSocket callback handlers --------------------------------------// @Override @@ -246,6 +252,17 @@ public class MQTTWSConnection extends WebSocketAdapter implements WebSocketListe //----- Internal implementation ------------------------------------------// + private void sendBytes(ByteSequence payload) throws IOException { + if (!isWritePartialFrames()) { + connection.getRemote().sendBytes(ByteBuffer.wrap(payload.data, payload.offset, payload.length)); + } else { + connection.getRemote().sendBytes(ByteBuffer.wrap( + payload.data, payload.offset, payload.length / 2)); + connection.getRemote().sendBytes(ByteBuffer.wrap( + payload.data, payload.offset + payload.length / 2, payload.length / 2)); + } + } + private void checkConnected() throws IOException { if (!isConnected()) { throw new IOException("MQTT WS Connection is closed."); diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSTransportTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSTransportTest.java index bf19259b85..f0a9ac9acc 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSTransportTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSSTransportTest.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -16,8 +16,28 @@ */ package org.apache.activemq.transport.ws; +import java.util.Arrays; +import java.util.Collection; + +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) public class MQTTWSSTransportTest extends MQTTWSTransportTest { + @Parameters(name="{0}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {"complete-frames", false}, + {"partial-frames", false} + }); + } + + public MQTTWSSTransportTest(String testName, boolean partialFrames) { + super(testName, partialFrames); + } + @Override protected String getWSConnectorURI() { return "wss://localhost:61623"; diff --git a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java index f48a110a62..74652453f0 100644 --- a/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java +++ b/activemq-http/src/test/java/org/apache/activemq/transport/ws/MQTTWSTransportTest.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -20,6 +20,8 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -36,13 +38,31 @@ import org.fusesource.mqtt.codec.PINGREQ; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +@RunWith(Parameterized.class) public class MQTTWSTransportTest extends WSTransportTestSupport { protected WebSocketClient wsClient; protected MQTTWSConnection wsMQTTConnection; protected ClientUpgradeRequest request; + protected boolean partialFrames; + + @Parameters(name="{0}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {"complete-frames", false}, + {"partial-frames", true} + }); + } + + public MQTTWSTransportTest(String testName, boolean partialFrames) { + this.partialFrames = partialFrames; + } + @Override @Before public void setUp() throws Exception { @@ -54,7 +74,7 @@ public class MQTTWSTransportTest extends WSTransportTestSupport { request = new ClientUpgradeRequest(); request.setSubProtocols("mqttv3.1"); - wsMQTTConnection = new MQTTWSConnection(); + wsMQTTConnection = new MQTTWSConnection().setWritePartialFrames(partialFrames); wsClient.connect(wsMQTTConnection, wsConnectUri, request); if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) { @@ -79,7 +99,7 @@ public class MQTTWSTransportTest extends WSTransportTestSupport { for (int i = 0; i < 10; ++i) { testConnect(); - wsMQTTConnection = new MQTTWSConnection(); + wsMQTTConnection = new MQTTWSConnection().setWritePartialFrames(partialFrames); wsClient.connect(wsMQTTConnection, wsConnectUri, request); if (!wsMQTTConnection.awaitConnection(30, TimeUnit.SECONDS)) {