diff --git a/artemis-commons/src/main/java/org/apache/activemq/artemis/api/core/SimpleString.java b/artemis-commons/src/main/java/org/apache/activemq/artemis/api/core/SimpleString.java index bce501c5fe..51a76fe35f 100644 --- a/artemis-commons/src/main/java/org/apache/activemq/artemis/api/core/SimpleString.java +++ b/artemis-commons/src/main/java/org/apache/activemq/artemis/api/core/SimpleString.java @@ -251,6 +251,16 @@ public final class SimpleString implements CharSequence, Serializable, Comparabl return true; } + /** + * returns true if the SimpleString parameter starts with the same char. false if not. + * + * @param other the char to look for + * @return true if this SimpleString starts with the same data + */ + public boolean startsWith(final char other) { + return data.length > 0 && data[0] == other; + } + @Override public String toString() { if (str == null) { diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java index 0eed8f68e8..6b1f437afa 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTSubscriptionManager.java @@ -43,6 +43,10 @@ import org.apache.activemq.artemis.utils.CompositeAddress; import org.jboss.logging.Logger; import static io.netty.handler.codec.mqtt.MqttProperties.MqttPropertyType.SUBSCRIPTION_IDENTIFIER; +import static org.apache.activemq.artemis.core.protocol.mqtt.MQTTUtil.DOLLAR; +import static org.apache.activemq.artemis.core.protocol.mqtt.MQTTUtil.HASH; +import static org.apache.activemq.artemis.core.protocol.mqtt.MQTTUtil.PLUS; +import static org.apache.activemq.artemis.core.protocol.mqtt.MQTTUtil.SLASH; import static org.apache.activemq.artemis.reader.MessageUtil.CONNECTION_ID_PROPERTY_NAME_STRING; public class MQTTSubscriptionManager { @@ -56,12 +60,16 @@ public class MQTTSubscriptionManager { private final ConcurrentMap consumers; /* - * We filter out certain messages (e.g. management messages, notifications, and messages from any address starting - * with '$'). This is because MQTT clients can do silly things like subscribe to '#' which matches ever address - * on the broker. + * We filter out certain messages (e.g. management messages, notifications) */ private final SimpleString messageFilter; + /* + * We can also filter out messages from any address starting with '$'. This is because MQTT clients can do silly + * things like subscribe to '#' which matches ever address on the broker. + */ + private final SimpleString messageFilterNoDollar; + public MQTTSubscriptionManager(MQTTSession session) { this.session = session; @@ -69,19 +77,21 @@ public class MQTTSubscriptionManager { consumerQoSLevels = new ConcurrentHashMap<>(); // Create filter string to ignore certain messages - StringBuilder builder = new StringBuilder(); - builder.append("NOT (("); - builder.append(FilterConstants.ACTIVEMQ_ADDRESS); - builder.append(" = '"); - builder.append(session.getServer().getConfiguration().getManagementAddress()); - builder.append("') OR ("); - builder.append(FilterConstants.ACTIVEMQ_ADDRESS); - builder.append(" = '"); - builder.append(session.getServer().getConfiguration().getManagementNotificationAddress()); - builder.append("') OR ("); - builder.append(FilterConstants.ACTIVEMQ_ADDRESS); - builder.append(" LIKE '$%'))"); // [MQTT-4.7.2-1] - messageFilter = new SimpleString(builder.toString()); + StringBuilder baseFilter = new StringBuilder(); + baseFilter.append("NOT ("); + baseFilter.append("(").append(FilterConstants.ACTIVEMQ_ADDRESS).append(" = '").append(session.getServer().getConfiguration().getManagementAddress()).append("')"); + baseFilter.append(" OR "); + baseFilter.append("(").append(FilterConstants.ACTIVEMQ_ADDRESS).append(" = '").append(session.getServer().getConfiguration().getManagementNotificationAddress()).append("')"); + + StringBuilder messageFilter = new StringBuilder(baseFilter); + messageFilter.append(")"); + this.messageFilter = new SimpleString(messageFilter.toString()); + + StringBuilder messageFilterNoDollar = new StringBuilder(baseFilter); + messageFilterNoDollar.append(" OR "); + messageFilterNoDollar.append("(").append(FilterConstants.ACTIVEMQ_ADDRESS).append(" LIKE '").append(DOLLAR).append("%')"); // [MQTT-4.7.2-1] + messageFilterNoDollar.append(")"); + this.messageFilterNoDollar = new SimpleString(messageFilterNoDollar.toString()); } synchronized void start() throws Exception { @@ -96,9 +106,9 @@ public class MQTTSubscriptionManager { // if using a shared subscription then parse the subscription name and topic if (topicName.startsWith(MQTTUtil.SHARED_SUBSCRIPTION_PREFIX)) { - int slashIndex = topicName.indexOf("/") + 1; - sharedSubscriptionName = topicName.substring(slashIndex, topicName.indexOf("/", slashIndex)); - topicName = topicName.substring(topicName.indexOf("/", slashIndex) + 1); + int slashIndex = topicName.indexOf(SLASH) + 1; + sharedSubscriptionName = topicName.substring(slashIndex, topicName.indexOf(SLASH, slashIndex)); + topicName = topicName.substring(topicName.indexOf(SLASH, slashIndex) + 1); } int qos = subscription.qualityOfService().value(); String coreAddress = MQTTUtil.convertMqttTopicFilterToCoreAddress(topicName, session.getWildcardConfiguration()); @@ -175,7 +185,7 @@ public class MQTTSubscriptionManager { private Queue findOrCreateQueue(BindingQueryResult bindingQueryResult, AddressInfo addressInfo, SimpleString queue, int qos) throws Exception { if (addressInfo.getRoutingTypes().contains(RoutingType.MULTICAST)) { - return session.getServerSession().createQueue(new QueueConfiguration(queue).setAddress(addressInfo.getName()).setFilterString(messageFilter).setDurable(MQTTUtil.DURABLE_MESSAGES && qos >= 0)); + return session.getServerSession().createQueue(new QueueConfiguration(queue).setAddress(addressInfo.getName()).setFilterString(getMessageFilter(addressInfo.getName())).setDurable(MQTTUtil.DURABLE_MESSAGES && qos >= 0)); } if (addressInfo.getRoutingTypes().contains(RoutingType.ANYCAST)) { @@ -191,7 +201,7 @@ public class MQTTSubscriptionManager { return session.getServer().locateQueue(name); } else { try { - return session.getServerSession().createQueue(new QueueConfiguration(addressInfo.getName()).setRoutingType(RoutingType.ANYCAST).setFilterString(messageFilter).setDurable(MQTTUtil.DURABLE_MESSAGES && qos >= 0)); + return session.getServerSession().createQueue(new QueueConfiguration(addressInfo.getName()).setRoutingType(RoutingType.ANYCAST).setFilterString(getMessageFilter(addressInfo.getName())).setDurable(MQTTUtil.DURABLE_MESSAGES && qos >= 0)); } catch (ActiveMQQueueExistsException e) { return session.getServer().locateQueue(addressInfo.getName()); } @@ -201,6 +211,14 @@ public class MQTTSubscriptionManager { throw ActiveMQMessageBundle.BUNDLE.invalidRoutingTypeForAddress(addressInfo.getRoutingType(), addressInfo.getName().toString(), EnumSet.allOf(RoutingType.class)); } + private SimpleString getMessageFilter(SimpleString topicFilter) { + if (topicFilter.startsWith(PLUS) || topicFilter.startsWith(HASH)) { + return messageFilterNoDollar; + } else { + return messageFilter; + } + } + private void createConsumerForSubscriptionQueue(Queue queue, String topic, int qos, boolean noLocal, Long existingConsumerId) throws Exception { long cid = existingConsumerId != null ? existingConsumerId : session.getServer().getStorageManager().generateID(); diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTUtil.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTUtil.java index 7b4582321b..ba9a476e33 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTUtil.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTUtil.java @@ -77,7 +77,15 @@ public class MQTTUtil { public static final boolean SESSION_AUTO_CREATE_QUEUE = false; - public static final String MQTT_RETAIN_ADDRESS_PREFIX = "$sys.mqtt.retain."; + public static final char DOLLAR = '$'; + + public static final char HASH = '#'; + + public static final char PLUS = '+'; + + public static final char SLASH = '/'; + + public static final String MQTT_RETAIN_ADDRESS_PREFIX = DOLLAR + "sys.mqtt.retain."; public static final SimpleString MQTT_QOS_LEVEL_KEY = SimpleString.toSimpleString("mqtt.qos.level"); @@ -101,9 +109,9 @@ public class MQTTUtil { public static final SimpleString MQTT_CONTENT_TYPE_KEY = SimpleString.toSimpleString("mqtt.content.type"); - public static final String MANAGEMENT_QUEUE_PREFIX = "$sys.mqtt.queue.qos2."; + public static final String MANAGEMENT_QUEUE_PREFIX = DOLLAR + "sys.mqtt.queue.qos2."; - public static final String SHARED_SUBSCRIPTION_PREFIX = "$share/"; + public static final String SHARED_SUBSCRIPTION_PREFIX = DOLLAR + "share/"; public static final long FOUR_BYTE_INT_MAX = Long.decode("0xFFFFFFFF"); // 4_294_967_295 @@ -154,9 +162,9 @@ public class MQTTUtil { public static class MQTTWildcardConfiguration extends WildcardConfiguration { public MQTTWildcardConfiguration() { - setDelimiter('/'); - setSingleWord('+'); - setAnyWords('#'); + setDelimiter(SLASH); + setSingleWord(PLUS); + setAnyWords(HASH); } } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/PahoMQTTTest.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/PahoMQTTTest.java index 54f85ec406..9124ba41f0 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/PahoMQTTTest.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/PahoMQTTTest.java @@ -190,4 +190,80 @@ public class PahoMQTTTest extends MQTTTestSupport { return new MqttClient(protocol + "://localhost:" + getPort(), clientId, new MemoryPersistence()); } + /* + * This test was adapted from a test from Eclipse Kapua submitted by a community member. + */ + @Test(timeout = 300000) + public void testDollarAndHashSubscriptions() throws Exception { + final String CLIENT_ID_ADMIN = "test-client-admin"; + final String CLIENT_ID_1 = "test-client-1"; + final String CLIENT_ID_2 = "test-client-2"; + + CountDownLatch clientAdminLatch = new CountDownLatch(3); + CountDownLatch client1Latch = new CountDownLatch(2); + CountDownLatch client2Latch = new CountDownLatch(1); + + MqttClient clientAdmin = createPahoClient(CLIENT_ID_ADMIN); + MqttClient client1 = createPahoClient(CLIENT_ID_1); + MqttClient client2 = createPahoClient(CLIENT_ID_2); + + clientAdmin.setCallback(new TestMqttClientCallback(clientAdminLatch)); + client1.setCallback(new TestMqttClientCallback(client1Latch)); + client2.setCallback(new TestMqttClientCallback(client2Latch)); + + clientAdmin.connect(); + client1.connect(); + client2.connect(); + + client1.subscribe("$dollar/" + CLIENT_ID_1 + "/#"); + client2.subscribe("$dollar/" + CLIENT_ID_2 + "/#"); + clientAdmin.subscribe("#"); + + MqttMessage m = new MqttMessage("test".getBytes()); + + client1.publish("$dollar/" + CLIENT_ID_1 + "/foo", m); + client2.publish("$dollar/" + CLIENT_ID_2 + "/foo", m); + clientAdmin.publish("$dollar/" + CLIENT_ID_1 + "/bar", m); + clientAdmin.publish("$dollar/" + CLIENT_ID_1 + "/bar", m); + + client1.publish("$dollar/" + CLIENT_ID_1 + "/baz", m); + client2.publish("$dollar/" + CLIENT_ID_2 + "/baz", m); + clientAdmin.publish("$dollar/" + CLIENT_ID_1 + "/baz", m); + clientAdmin.publish("$dollar/" + CLIENT_ID_2 + "/baz", m); + + assertTrue(client1Latch.await(2, TimeUnit.SECONDS)); + assertTrue(client2Latch.await(2, TimeUnit.SECONDS)); + assertFalse(clientAdminLatch.await(1, TimeUnit.SECONDS)); + assertEquals(3, clientAdminLatch.getCount()); + + clientAdmin.disconnect(); + clientAdmin.close(); + client1.disconnect(); + client1.close(); + client2.disconnect(); + client2.close(); + } + + private class TestMqttClientCallback implements MqttCallback { + + private CountDownLatch latch; + + TestMqttClientCallback(CountDownLatch latch) { + this.latch = latch; + } + + @Override + public void messageArrived(String topic, MqttMessage message) throws Exception { + latch.countDown(); + } + + @Override + public void deliveryComplete(IMqttDeliveryToken token) { + } + + @Override + public void connectionLost(Throwable cause) { + } + } + }