diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java index be42d2fefb..9800be5ee4 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java @@ -26,7 +26,6 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.mqtt.MqttConnectReturnCode; import org.apache.activemq.artemis.api.core.client.ActiveMQClient; import org.apache.activemq.artemis.core.server.ActiveMQServer; -import org.apache.activemq.artemis.core.server.ServerMessage; import org.apache.activemq.artemis.core.server.ServerSession; import org.apache.activemq.artemis.core.server.impl.ServerSessionImpl; import org.apache.activemq.artemis.utils.ConcurrentHashSet; @@ -45,6 +44,16 @@ public class MQTTConnectionManager { private MQTTLogger log = MQTTLogger.LOGGER; + private boolean isWill = false; + + private ByteBuf willMessage; + + private String willTopic; + + private int willQoSLevel; + + private boolean willRetain; + public MQTTConnectionManager(MQTTSession session) { this.session = session; MQTTFailureListener failureListener = new MQTTFailureListener(this); @@ -66,7 +75,7 @@ public class MQTTConnectionManager { String clientId = validateClientId(cId, cleanSession); if (clientId == null) { session.getProtocolHandler().sendConnack(MqttConnectReturnCode.CONNECTION_REFUSED_IDENTIFIER_REJECTED); - session.getProtocolHandler().disconnect(); + session.getProtocolHandler().disconnect(true); return; } @@ -78,11 +87,13 @@ public class MQTTConnectionManager { session.setIsClean(cleanSession); if (will) { + isWill = true; byte[] payload = willMessage.getBytes(Charset.forName("UTF-8")); - ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(payload.length); - buf.writeBytes(payload); - ServerMessage w = MQTTUtil.createServerMessageFromByteBuf(session, willTopic, willRetain, willQosLevel, buf); - session.getSessionState().setWillMessage(w); + this.willMessage = ByteBufAllocator.DEFAULT.buffer(payload.length); + this.willMessage.writeBytes(payload); + this.willQoSLevel = willQosLevel; + this.willRetain = willRetain; + this.willTopic = willTopic; } session.getConnection().setConnected(true); @@ -119,18 +130,17 @@ public class MQTTConnectionManager { return (ServerSessionImpl) serverSession; } - synchronized void disconnect() { - if (session == null) { + synchronized void disconnect(boolean failure) { + if (session == null || session.getStopped()) { return; } try { + if (isWill && failure) { + session.getMqttPublishManager().sendInternal(0, willTopic, willQoSLevel, willMessage, willRetain, true); + } session.stop(); session.getConnection().destroy(); - - if (session.getState().isWill()) { - session.getConnectionManager().sendWill(); - } } catch (Exception e) { log.error("Error disconnecting client: " + e.getMessage()); } finally { @@ -144,11 +154,6 @@ public class MQTTConnectionManager { } } - private void sendWill() throws Exception { - session.getServer().getPostOffice().route(session.getSessionState().getWillMessage(), true); - session.getSessionState().deleteWillMessage(); - } - private MQTTSessionState getSessionState(String clientId) throws InterruptedException { /* [MQTT-3.1.2-6] If CleanSession is set to 1, the Client and Server MUST discard any previous Session and * start a new one This Session lasts as long as the Network Connection. State data associated with this Session diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTFailureListener.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTFailureListener.java index 7bd9fadae5..4a98d159aa 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTFailureListener.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTFailureListener.java @@ -34,11 +34,11 @@ public class MQTTFailureListener implements FailureListener { @Override public void connectionFailed(ActiveMQException exception, boolean failedOver) { - connectionManager.disconnect(); + connectionManager.disconnect(true); } @Override public void connectionFailed(ActiveMQException exception, boolean failedOver, String scaleDownTargetNodeID) { - connectionManager.disconnect(); + connectionManager.disconnect(true); } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java index b3587a3a2d..b084f9d067 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolHandler.java @@ -39,9 +39,9 @@ import io.netty.handler.codec.mqtt.MqttSubAckPayload; import io.netty.handler.codec.mqtt.MqttSubscribeMessage; import io.netty.handler.codec.mqtt.MqttUnsubAckMessage; import io.netty.handler.codec.mqtt.MqttUnsubscribeMessage; +import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.api.core.SimpleString; import org.apache.activemq.artemis.core.server.ActiveMQServer; -import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.spi.core.protocol.ConnectionEntry; /** @@ -89,7 +89,7 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { public void channelRead(ChannelHandlerContext ctx, Object msg) { try { if (stopped) { - disconnect(); + disconnect(true); return; } @@ -98,7 +98,7 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { // Disconnect if Netty codec failed to decode the stream. if (message.decoderResult().isFailure()) { log.debug("Bad Message Disconnecting Client."); - disconnect(); + disconnect(true); return; } @@ -150,11 +150,11 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { handleDisconnect(message); break; default: - disconnect(); + disconnect(true); } } catch (Exception e) { log.debug("Error processing Control Packet, Disconnecting Client", e); - disconnect(); + disconnect(true); } } @@ -171,8 +171,8 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { session.getConnectionManager().connect(clientId, connect.payload().userName(), connect.payload().password(), connect.variableHeader().isWillFlag(), connect.payload().willMessage(), connect.payload().willTopic(), connect.variableHeader().isWillRetain(), connect.variableHeader().willQos(), connect.variableHeader().isCleanSession()); } - void disconnect() { - session.getConnectionManager().disconnect(); + void disconnect(boolean error) { + session.getConnectionManager().disconnect(error); } void sendConnack(MqttConnectReturnCode returnCode) { @@ -193,7 +193,7 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { void handleConnack(MqttConnAckMessage message) { log.debug("Received invalid CONNACK from client: " + session.getSessionState().getClientId()); log.debug("Disconnecting client: " + session.getSessionState().getClientId()); - disconnect(); + disconnect(true); } void handlePublish(MqttPublishMessage message) throws Exception { @@ -257,7 +257,7 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { } void handleSuback(MqttSubAckMessage message) { - disconnect(); + disconnect(true); } void handleUnsubscribe(MqttUnsubscribeMessage message) throws Exception { @@ -270,7 +270,7 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { } void handleUnsuback(MqttUnsubAckMessage message) { - disconnect(); + disconnect(true); } void handlePingreq(MqttMessage message, ChannelHandlerContext ctx) { @@ -281,13 +281,11 @@ public class MQTTProtocolHandler extends ChannelInboundHandlerAdapter { } void handlePingresp(MqttMessage message) { - disconnect(); + disconnect(true); } void handleDisconnect(MqttMessage message) { - if (session.getSessionState() != null) - session.getState().deleteWillMessage(); - disconnect(); + disconnect(false); } protected int send(int messageId, String topicName, int qosLevel, ByteBuf payload, int deliveryCount) { diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java index 26886c6db9..76f15c0938 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTPublishManager.java @@ -32,6 +32,7 @@ import org.apache.activemq.artemis.core.server.Queue; import org.apache.activemq.artemis.core.server.ServerConsumer; import org.apache.activemq.artemis.core.server.ServerMessage; import org.apache.activemq.artemis.core.server.impl.ServerMessageImpl; +import org.apache.activemq.artemis.core.transaction.Transaction; /** * Handles MQTT Exactly Once (QoS level 2) Protocol. @@ -133,6 +134,20 @@ public class MQTTPublishManager { // INBOUND void handleMessage(int messageId, String topic, int qos, ByteBuf payload, boolean retain) throws Exception { + sendInternal(messageId, topic, qos, payload, retain, false); + } + + /** + * Sends a message either on behalf of the client or on behalf of the broker (Will Messages) + * @param messageId + * @param topic + * @param qos + * @param payload + * @param retain + * @param internal if true means on behalf of the broker (skips authorisation) and does not return ack. + * @throws Exception + */ + void sendInternal(int messageId, String topic, int qos, ByteBuf payload, boolean retain, boolean internal) throws Exception { synchronized (lock) { ServerMessage serverMessage = MQTTUtil.createServerMessageFromByteBuf(session, topic, retain, qos, payload); @@ -141,17 +156,23 @@ public class MQTTPublishManager { } if (qos < 2 || !state.getPubRec().contains(messageId)) { - if (qos == 2) + if (qos == 2 && !internal) state.getPubRec().add(messageId); - session.getServerSession().send(serverMessage, true); - } - if (retain) { - boolean reset = payload instanceof EmptyByteBuf || payload.capacity() == 0; - session.getRetainMessageManager().handleRetainedMessage(serverMessage, topic, reset); + Transaction tx = session.getServerSession().newTransaction(); + try { + session.getServerSession().send(tx, serverMessage, true, false); + if (retain) { + boolean reset = payload instanceof EmptyByteBuf || payload.capacity() == 0; + session.getRetainMessageManager().handleRetainedMessage(serverMessage, topic, reset, tx); + } + tx.commit(); + } catch (Throwable t) { + tx.rollback(); + throw t; + } + createMessageAck(messageId, qos, internal); } - - createMessageAck(messageId, qos); } } @@ -182,14 +203,16 @@ public class MQTTPublishManager { } } - private void createMessageAck(final int messageId, final int qos) { + private void createMessageAck(final int messageId, final int qos, final boolean internal) { session.getServer().getStorageManager().afterCompleteOperations(new IOCallback() { @Override public void done() { - if (qos == 1) { - session.getProtocolHandler().sendPubAck(messageId); - } else if (qos == 2) { - session.getProtocolHandler().sendPubRec(messageId); + if (!internal) { + if (qos == 1) { + session.getProtocolHandler().sendPubAck(messageId); + } else if (qos == 2) { + session.getProtocolHandler().sendPubRec(messageId); + } } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTRetainMessageManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTRetainMessageManager.java index 27423d8452..596670bf94 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTRetainMessageManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTRetainMessageManager.java @@ -17,13 +17,15 @@ package org.apache.activemq.artemis.core.protocol.mqtt; -import java.util.Iterator; - import org.apache.activemq.artemis.api.core.SimpleString; import org.apache.activemq.artemis.core.server.BindingQueryResult; import org.apache.activemq.artemis.core.server.MessageReference; import org.apache.activemq.artemis.core.server.Queue; +import org.apache.activemq.artemis.core.server.RoutingContext; import org.apache.activemq.artemis.core.server.ServerMessage; +import org.apache.activemq.artemis.core.server.impl.RoutingContextImpl; +import org.apache.activemq.artemis.core.transaction.Transaction; +import org.apache.activemq.artemis.utils.LinkedListIterator; public class MQTTRetainMessageManager { @@ -42,7 +44,7 @@ public class MQTTRetainMessageManager { * the subscription queue for the consumer. When a new retained message is received the message will be sent to * the retained queue and the previous retain message consumed to remove it from the queue. */ - void handleRetainedMessage(ServerMessage message, String address, boolean reset) throws Exception { + void handleRetainedMessage(ServerMessage message, String address, boolean reset, Transaction tx) throws Exception { SimpleString retainAddress = new SimpleString(MQTTUtil.convertMQTTAddressFilterToCoreRetain(address, session.getWildcardConfiguration())); Queue queue = session.getServer().locateQueue(retainAddress); @@ -50,39 +52,52 @@ public class MQTTRetainMessageManager { queue = session.getServerSession().createQueue(retainAddress, retainAddress, null, false, true); } - // Set the address of this message to the retained queue. - message.setAddress(retainAddress); - Iterator iterator = queue.iterator(); - synchronized (iterator) { - if (iterator.hasNext()) { - Long messageId = iterator.next().getMessage().getMessageID(); - queue.deleteReference(messageId); - } + try (LinkedListIterator iterator = queue.iterator()) { + synchronized (queue) { + if (iterator.hasNext()) { + MessageReference ref = iterator.next(); + iterator.remove(); + queue.acknowledge(tx, ref); + } - if (!reset) { - session.getServerSession().send(message.copy(), true); - } - } - } - - void addRetainedMessagesToQueue(Queue queue, String address) throws Exception { - // Queue to add the retained messages to - - // The address filter that matches all retained message queues. - String retainAddress = MQTTUtil.convertMQTTAddressFilterToCoreRetain(address, session.getWildcardConfiguration()); - BindingQueryResult bindingQueryResult = session.getServerSession().executeBindingQuery(new SimpleString(retainAddress)); - - // Iterate over all matching retain queues and add the head message to the original queue. - for (SimpleString retainedQueueName : bindingQueryResult.getQueueNames()) { - Queue retainedQueue = session.getServer().locateQueue(retainedQueueName); - synchronized (this) { - Iterator i = retainedQueue.iterator(); - if (i.hasNext()) { - ServerMessage message = i.next().getMessage().copy(session.getServer().getStorageManager().generateID()); - queue.addTail(message.createReference(queue), true); + if (!reset) { + sendToQueue(message.copy(session.getServer().getStorageManager().generateID()), queue, tx); } } } } + + // SEND to Queue. + void addRetainedMessagesToQueue(Queue queue, String address) throws Exception { + // The address filter that matches all retained message queues. + String retainAddress = MQTTUtil.convertMQTTAddressFilterToCoreRetain(address, session.getWildcardConfiguration()); + BindingQueryResult bindingQueryResult = session.getServerSession().executeBindingQuery(new SimpleString(retainAddress)); + + // Iterate over all matching retain queues and add the queue + Transaction tx = session.getServerSession().newTransaction(); + try { + synchronized (queue) { + for (SimpleString retainedQueueName : bindingQueryResult.getQueueNames()) { + Queue retainedQueue = session.getServer().locateQueue(retainedQueueName); + try (LinkedListIterator i = retainedQueue.iterator()) { + if (i.hasNext()) { + ServerMessage message = i.next().getMessage().copy(session.getServer().getStorageManager().generateID()); + sendToQueue(message, queue, tx); + } + } + } + } + } catch (Throwable t) { + tx.rollback(); + throw t; + } + tx.commit(); + } + + private void sendToQueue(ServerMessage message, Queue queue, Transaction tx) throws Exception { + RoutingContext context = new RoutingContextImpl(tx); + queue.route(message, context); + session.getServer().getPostOffice().processRoute(message, context, false); + } } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java index 9458f8b080..31452bfacc 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSessionState.java @@ -29,14 +29,11 @@ import java.util.concurrent.atomic.AtomicInteger; import io.netty.handler.codec.mqtt.MqttTopicSubscription; import org.apache.activemq.artemis.api.core.Pair; import org.apache.activemq.artemis.core.config.WildcardConfiguration; -import org.apache.activemq.artemis.core.server.ServerMessage; public class MQTTSessionState { private String clientId; - private ServerMessage willMessage; - private final ConcurrentMap subscriptions = new ConcurrentHashMap<>(); // Used to store Packet ID of Publish QoS1 and QoS2 message. See spec: 4.3.3 QoS 2: Exactly once delivery. Method B. @@ -60,7 +57,6 @@ public class MQTTSessionState { addressMessageMap.clear(); pubRec.clear(); outboundStore.clear(); - willMessage = null; } OutboundStore getOutboundStore() { @@ -79,22 +75,6 @@ public class MQTTSessionState { this.attached = attached; } - boolean isWill() { - return willMessage != null; - } - - ServerMessage getWillMessage() { - return willMessage; - } - - void setWillMessage(ServerMessage willMessage) { - this.willMessage = willMessage; - } - - void deleteWillMessage() { - willMessage = null; - } - Collection getSubscriptions() { return subscriptions.values(); } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java index a26a046b2c..7a12f4261d 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java @@ -37,10 +37,10 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; +import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.api.core.SimpleString; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTConnectionManager; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSession; -import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.core.server.impl.AddressInfo; import org.apache.activemq.artemis.tests.util.Wait; import org.fusesource.mqtt.client.BlockingConnection; @@ -1029,6 +1029,43 @@ public class MQTTTest extends MQTTTestSupport { assertEquals("test message", new String(m.getPayload())); } + @Test(timeout = 60 * 1000) + public void testWillMessageIsRetained() throws Exception { + getServer().createQueue(SimpleString.toSimpleString("will"), RoutingType.MULTICAST, SimpleString.toSimpleString("will"), null, true, false); + + MQTT mqtt = createMQTTConnection("1", false); + mqtt.setKeepAlive((short) 1); + mqtt.setWillMessage("test message"); + mqtt.setWillTopic("will"); + mqtt.setWillQos(QoS.AT_LEAST_ONCE); + mqtt.setWillRetain(true); + + final BlockingConnection connection = mqtt.blockingConnection(); + connection.connect(); + Wait.waitFor(new Wait.Condition() { + @Override + public boolean isSatisfied() throws Exception { + return connection.isConnected(); + } + }); + + // kill transport + connection.kill(); + + Thread.sleep(10000); + + MQTT mqtt2 = createMQTTConnection("2", false); + BlockingConnection connection2 = mqtt2.blockingConnection(); + connection2.connect(); + connection2.subscribe(new Topic[]{new Topic("will", QoS.AT_LEAST_ONCE)}); + + Message m = connection2.receive(1000, TimeUnit.MILLISECONDS); + assertNotNull(m); + m.ack(); + assertEquals("test message", new String(m.getPayload())); + } + + @Test(timeout = 60 * 1000) public void testCleanSession() throws Exception { final String CLIENTID = "cleansession";