diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java index f0a7f2ecc9..bc762e88e9 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java @@ -19,7 +19,6 @@ package org.apache.activemq.artemis.core.protocol.mqtt; import java.util.EnumSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -34,7 +33,6 @@ import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.api.core.SimpleString; import org.apache.activemq.artemis.core.server.ActiveMQMessageBundle; import org.apache.activemq.artemis.core.server.BindingQueryResult; -import org.apache.activemq.artemis.core.server.Consumer; import org.apache.activemq.artemis.core.server.Queue; import org.apache.activemq.artemis.core.server.ServerConsumer; import org.apache.activemq.artemis.core.server.impl.AddressInfo; @@ -104,12 +102,7 @@ public class MQTTSubscriptionManager { private void addSubscription(MqttTopicSubscription subscription, Integer subscriptionIdentifier, boolean initialStart) throws Exception { String rawTopicName = CompositeAddress.extractAddressName(subscription.topicName()); - String parsedTopicName = rawTopicName; - - // if using a shared subscription then parse - if (rawTopicName.startsWith(MQTTUtil.SHARED_SUBSCRIPTION_PREFIX)) { - parsedTopicName = rawTopicName.substring(rawTopicName.indexOf(SLASH, rawTopicName.indexOf(SLASH) + 1) + 1); - } + String parsedTopicName = parseTopicName(rawTopicName); int qos = subscription.qualityOfService().value(); String coreAddress = MQTTUtil.convertMqttTopicFilterToCoreAddress(parsedTopicName, session.getWildcardConfiguration()); @@ -138,6 +131,16 @@ public class MQTTSubscriptionManager { } } + private String parseTopicName(String rawTopicName) { + String parsedTopicName = rawTopicName; + + // if using a shared subscription then parse + if (rawTopicName.startsWith(MQTTUtil.SHARED_SUBSCRIPTION_PREFIX)) { + parsedTopicName = rawTopicName.substring(rawTopicName.indexOf(SLASH, rawTopicName.indexOf(SLASH) + 1) + 1); + } + return parsedTopicName; + } + synchronized void stop() throws Exception { for (ServerConsumer consumer : consumers.values()) { consumer.setStarted(false); @@ -227,7 +230,7 @@ public class MQTTSubscriptionManager { // for noLocal support we use the MQTT *client id* rather than the connection ID, but we still use the existing property name ServerConsumer consumer = session.getServerSession().createConsumer(cid, queue.getName(), noLocal ? SimpleString.toSimpleString(CONNECTION_ID_PROPERTY_NAME_STRING + " <> '" + session.getState().getClientId() + "'") : null, false, false, -1); - ServerConsumer existingConsumer = consumers.put(topic, consumer); + ServerConsumer existingConsumer = consumers.put(parseTopicName(topic), consumer); if (existingConsumer != null) { existingConsumer.setStarted(false); existingConsumer.close(false); @@ -255,45 +258,28 @@ public class MQTTSubscriptionManager { return removeSubscription(address, true); } - private short removeSubscription(String address, boolean enforceSecurity) { - if (session.getState().getSubscription(address) == null) { + private short removeSubscription(String topic, boolean enforceSecurity) { + if (session.getState().getSubscription(topic) == null) { return MQTTReasonCodes.NO_SUBSCRIPTION_EXISTED; } short reasonCode = MQTTReasonCodes.SUCCESS; try { - SimpleString internalQueueName = getQueueNameForTopic(address); - session.getState().removeSubscription(address); + session.getState().removeSubscription(topic); - Queue queue = session.getServer().locateQueue(internalQueueName); - AddressInfo addressInfo = session.getServerSession().getAddress(SimpleString.toSimpleString(MQTTUtil.convertMqttTopicFilterToCoreAddress(address, session.getWildcardConfiguration()))); - if (addressInfo != null && addressInfo.getRoutingTypes().contains(RoutingType.ANYCAST)) { - ServerConsumer consumer = consumers.get(address); - consumers.remove(address); - if (consumer != null) { - consumer.close(false); - consumerQoSLevels.remove(consumer.getID()); - } - } else { - consumers.remove(address); - Set queueConsumers; - if (queue != null && (queueConsumers = (Set) queue.getConsumers()) != null) { - for (Consumer consumer : queueConsumers) { - if (consumer instanceof ServerConsumer) { - ((ServerConsumer) consumer).close(false); - consumerQoSLevels.remove(((ServerConsumer) consumer).getID()); - } - } - } + ServerConsumer removed = consumers.remove(parseTopicName(topic)); + if (removed != null) { + removed.close(false); + consumerQoSLevels.remove(removed.getID()); } + SimpleString internalQueueName = getQueueNameForTopic(topic); + Queue queue = session.getServer().locateQueue(internalQueueName); if (queue != null) { - assert session.getServerSession().executeQueueQuery(internalQueueName).isExists(); - if (queue.isConfigurationManaged()) { queue.deleteAllReferences(); - } else { + } else if (!topic.startsWith(MQTTUtil.SHARED_SUBSCRIPTION_PREFIX) || (topic.startsWith(MQTTUtil.SHARED_SUBSCRIPTION_PREFIX) && queue.getConsumerCount() == 0)) { session.getServerSession().deleteQueue(internalQueueName, enforceSecurity); } } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java index 26f00e7a18..5679ae8168 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5Test.java @@ -36,8 +36,10 @@ import org.apache.activemq.artemis.core.server.Queue; import org.apache.activemq.artemis.core.settings.impl.AddressSettings; import org.apache.activemq.artemis.logs.AssertionLoggerHandler; import org.apache.activemq.artemis.tests.util.RandomUtil; +import org.apache.activemq.artemis.utils.ReusableLatch; import org.apache.activemq.artemis.utils.Wait; import org.eclipse.paho.mqttv5.client.MqttAsyncClient; +import org.eclipse.paho.mqttv5.client.MqttCallback; import org.eclipse.paho.mqttv5.client.MqttClient; import org.eclipse.paho.mqttv5.client.MqttConnectionOptions; import org.eclipse.paho.mqttv5.client.MqttConnectionOptionsBuilder; @@ -370,4 +372,79 @@ public class MQTT5Test extends MQTT5TestSupport { consumer.disconnect(); consumer.close(); } + + @Test(timeout = DEFAULT_TIMEOUT) + public void testSharedSubscriptionQueueRemoval() throws Exception { + final String TOPIC = "myTopic"; + final String SUB_NAME = "myShare"; + final String SHARED_SUB = MQTTUtil.SHARED_SUBSCRIPTION_PREFIX + SUB_NAME + "/" + TOPIC; + ReusableLatch ackLatch = new ReusableLatch(1); + + MqttCallback mqttCallback = new DefaultMqttCallback() { + @Override + public void messageArrived(String topic, org.eclipse.paho.mqttv5.common.MqttMessage message) throws Exception { + ackLatch.countDown(); + } + }; + + // create consumer 1 + MqttClient consumer1 = createPahoClient("consumer1"); + consumer1.connect(); + consumer1.setCallback(mqttCallback); + consumer1.subscribe(SHARED_SUB, 1); + + // create consumer 2 + MqttClient consumer2 = createPahoClient("consumer2"); + consumer2.connect(); + consumer2.setCallback(mqttCallback); + consumer2.subscribe(SHARED_SUB, 1); + + // verify there are 2 consumers on the subscription queue + Queue sharedSubQueue = server.locateQueue(SUB_NAME.concat(".").concat(TOPIC)); + assertNotNull(sharedSubQueue); + assertEquals(TOPIC, sharedSubQueue.getAddress().toString()); + assertEquals(2, sharedSubQueue.getConsumerCount()); + + // send a message + MqttClient producer = createPahoClient("producer"); + producer.connect(); + producer.publish(TOPIC, new byte[0], 1, false); + + // ensure one of the consumers receives the message + assertTrue(ackLatch.await(2, TimeUnit.SECONDS)); + + // disconnect the first consumer + consumer1.disconnect(); + + // verify there is 1 consumer on the subscription queue + sharedSubQueue = server.locateQueue(SUB_NAME.concat(".").concat(TOPIC)); + assertNotNull(sharedSubQueue); + assertEquals(TOPIC, sharedSubQueue.getAddress().toString()); + assertEquals(1, sharedSubQueue.getConsumerCount()); + + // send a message and ensure the remaining consumer receives it + ackLatch.countUp(); + producer.publish(TOPIC, new byte[0], 1, false); + assertTrue(ackLatch.await(2, TimeUnit.SECONDS)); + + // reconnect previous consumer + consumer1.connect(); + consumer1.setCallback(mqttCallback); + consumer1.subscribe(SHARED_SUB, 1); + + // send a message and ensure one of the consumers receives it + ackLatch.countUp(); + producer.publish(TOPIC, new byte[0], 1, false); + assertTrue(ackLatch.await(2, TimeUnit.SECONDS)); + + producer.disconnect(); + producer.close(); + consumer1.disconnect(); + consumer1.close(); + consumer2.disconnect(); + consumer2.close(); + + // verify the shared subscription queue is removed after all the subscribers disconnect + Wait.assertTrue(() -> server.locateQueue(SUB_NAME.concat(".").concat(TOPIC)) == null, 2000, 100); + } } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5TestSupport.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5TestSupport.java index 6d4abc9c6a..8482f9573b 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5TestSupport.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt5/MQTT5TestSupport.java @@ -19,6 +19,7 @@ package org.apache.activemq.artemis.tests.integration.mqtt5; import javax.jms.ConnectionFactory; import java.io.File; import java.io.IOException; +import java.lang.invoke.MethodHandles; import java.security.ProtectionDomain; import java.util.Arrays; import java.util.Collection; @@ -74,7 +75,6 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.lang.invoke.MethodHandles; import static java.util.Collections.singletonList; import static org.apache.activemq.artemis.core.protocol.mqtt.MQTTProtocolManagerFactory.MQTT_PROTOCOL_NAME;