diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/region/RegionBroker.java b/activemq-broker/src/main/java/org/apache/activemq/broker/region/RegionBroker.java index 46c6de14f5..da6e1fadb3 100755 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/region/RegionBroker.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/region/RegionBroker.java @@ -215,6 +215,10 @@ public class RegionBroker extends EmptyBroker { return brokerService != null ? brokerService.getDestinationPolicy() : null; } + public ConnectionContext getConnectionContext(String clientId) { + return clientIdSet.get(clientId); + } + @Override public void addConnection(ConnectionContext context, ConnectionInfo info) throws Exception { String clientId = info.getClientId(); diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/region/Topic.java b/activemq-broker/src/main/java/org/apache/activemq/broker/region/Topic.java index 4744af8e3b..0186b4240e 100755 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/region/Topic.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/region/Topic.java @@ -305,7 +305,7 @@ public class Topic extends BaseDestination implements Task { sub.remove(context, this, dispatched); } - protected void recoverRetroactiveMessages(ConnectionContext context, Subscription subscription) throws Exception { + public void recoverRetroactiveMessages(ConnectionContext context, Subscription subscription) throws Exception { if (subscription.getConsumerInfo().isRetroactive()) { subscriptionRecoveryPolicy.recover(context, this, subscription); } diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/region/policy/RetainedMessageSubscriptionRecoveryPolicy.java b/activemq-broker/src/main/java/org/apache/activemq/broker/region/policy/RetainedMessageSubscriptionRecoveryPolicy.java index d350a5f268..ba2d1a14a4 100644 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/region/policy/RetainedMessageSubscriptionRecoveryPolicy.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/region/policy/RetainedMessageSubscriptionRecoveryPolicy.java @@ -37,8 +37,8 @@ import org.apache.activemq.filter.DestinationFilter; */ public class RetainedMessageSubscriptionRecoveryPolicy implements SubscriptionRecoveryPolicy { - public static final String RETAIN_PROPERTY = "ActiveMQRetain"; - public static final String RETAINED_PROPERTY = "ActiveMQRetained"; + public static final String RETAIN_PROPERTY = "ActiveMQ.Retain"; + public static final String RETAINED_PROPERTY = "ActiveMQ.Retained"; private volatile MessageReference retainedMessage; private SubscriptionRecoveryPolicy wrapped; diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java index ebb9f455d2..cbb6415d67 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java @@ -19,6 +19,7 @@ package org.apache.activemq.transport.mqtt; import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.zip.DataFormatException; @@ -28,6 +29,10 @@ import javax.jms.JMSException; import javax.jms.Message; import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.broker.ConnectionContext; +import org.apache.activemq.broker.region.RegionBroker; +import org.apache.activemq.broker.region.Subscription; +import org.apache.activemq.broker.region.TopicRegion; import org.apache.activemq.broker.region.policy.RetainedMessageSubscriptionRecoveryPolicy; import org.apache.activemq.command.*; import org.apache.activemq.store.PersistenceAdapterSupport; @@ -80,7 +85,7 @@ public class MQTTProtocolConverter { private String clientId; private long defaultKeepAlive; private int activeMQSubscriptionPrefetch=1; - private final String QOS_PROPERTY_NAME = "QoSPropertyName"; + protected static final String QOS_PROPERTY_NAME = "ActiveMQ.MQTT.QoS"; private final MQTTPacketIdGenerator packetIdGenerator; public MQTTProtocolConverter(MQTTTransport mqttTransport, BrokerService brokerService) { @@ -353,20 +358,21 @@ public class MQTTProtocolConverter { final UTF8Buffer topicName = topic.name(); final QoS topicQoS = topic.qos(); + ActiveMQDestination destination = new ActiveMQTopic(convertMQTTToActiveMQ(topicName.toString())); + if( mqttSubscriptionByTopic.containsKey(topicName)) { if (topicQoS != mqttSubscriptionByTopic.get(topicName).qos()) { // remove old subscription as the QoS has changed onUnSubscribe(topicName); } else { - // duplicate SUBSCRIBE packet - // TODO find all matching topics and resend retained messages + // duplicate SUBSCRIBE packet, find all matching topics and resend retained messages + resendRetainedMessages(topicName, destination); + return (byte) topicQoS.ordinal(); } onUnSubscribe(topicName); } - ActiveMQDestination destination = new ActiveMQTopic(convertMQTTToActiveMQ(topicName.toString())); - ConsumerId id = new ConsumerId(sessionId, consumerIdGenerator.getNextSequenceId()); ConsumerInfo consumerInfo = new ConsumerInfo(id); consumerInfo.setDestination(destination); @@ -402,6 +408,40 @@ public class MQTTProtocolConverter { return qos[0]; } + private void resendRetainedMessages(UTF8Buffer topicName, ActiveMQDestination destination) throws MQTTProtocolException { + // get TopicRegion + RegionBroker regionBroker; + try { + regionBroker = (RegionBroker) brokerService.getBroker().getAdaptor(RegionBroker.class); + } catch (Exception e) { + throw new MQTTProtocolException("Error subscribing to " + topicName + ": " + e.getMessage(), false, e); + } + final TopicRegion topicRegion = (TopicRegion) regionBroker.getTopicRegion(); + + // get all matching Topics + final Set matchingDestinations = topicRegion.getDestinations(destination); + for (org.apache.activemq.broker.region.Destination dest : matchingDestinations) { + // find matching MQTT subscription for this client + final String mqttTopicName = convertActiveMQToMQTT(dest.getName()); + final MQTTSubscription mqttSubscription = mqttSubscriptionByTopic.get(new UTF8Buffer(mqttTopicName)); + if (mqttSubscription != null) { + // recover retroactive messages for matching subscription + final ConsumerInfo consumerInfo = mqttSubscription.getConsumerInfo(); + final ConsumerId consumerId = consumerInfo.getConsumerId(); + final Subscription subscription = topicRegion.getSubscriptions().get(consumerId); + + // use actual client id used to create connection to lookup connection context + final ConnectionContext connectionContext = regionBroker.getConnectionContext(connectionInfo.getClientId()); + try { + ((org.apache.activemq.broker.region.Topic)dest).recoverRetroactiveMessages(connectionContext, subscription); + } catch (Exception e) { + throw new MQTTProtocolException("Error recovering retained messages for " + + mqttTopicName + ": " + e.getMessage(), false, e); + } + } + } + } + void onUnSubscribe(UNSUBSCRIBE command) throws MQTTProtocolException { checkConnected(); UTF8Buffer[] topics = command.topics(); @@ -579,7 +619,7 @@ public class MQTTProtocolConverter { synchronized (mqttTopicMap) { topicName = mqttTopicMap.get(message.getJMSDestination()); if (topicName == null) { - topicName = new UTF8Buffer(message.getDestination().getPhysicalName().replace('.', '/')); + topicName = new UTF8Buffer(convertActiveMQToMQTT(message.getDestination().getPhysicalName())); mqttTopicMap.put(message.getJMSDestination(), topicName); } } @@ -626,6 +666,10 @@ public class MQTTProtocolConverter { return result; } + private String convertActiveMQToMQTT(String physicalName) { + return physicalName.replace('.', '/'); + } + public MQTTTransport getMQTTTransport() { return mqttTransport; } diff --git a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java index 37016b8097..e11f6e9040 100644 --- a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java +++ b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java @@ -44,6 +44,7 @@ import org.apache.activemq.broker.TransportConnector; import org.apache.activemq.broker.region.policy.LastImageSubscriptionRecoveryPolicy; import org.apache.activemq.broker.region.policy.PolicyEntry; import org.apache.activemq.broker.region.policy.PolicyMap; +import org.apache.activemq.broker.region.policy.RetainedMessageSubscriptionRecoveryPolicy; import org.apache.activemq.command.ActiveMQMessage; import org.apache.activemq.command.ActiveMQTopic; import org.apache.activemq.filter.DestinationMapEntry; @@ -414,12 +415,6 @@ public class MQTTTest extends AbstractMQTTTest { msg = connection.receive(5000, TimeUnit.MILLISECONDS); } while (msg != null); - // connection is borked after timeout in connection.receive() - connection.disconnect(); - connection = mqtt.blockingConnection(); - connection.connect(); - connection.subscribe(new Topic[] { new Topic(wildcard, QoS.AT_LEAST_ONCE) }); - // test non-retained message for (String topic : topics) { connection.publish(topic, topic.getBytes(), QoS.AT_LEAST_ONCE, false); @@ -523,6 +518,7 @@ public class MQTTTest extends AbstractMQTTTest { waitCount++; } assertEquals(qos.ordinal(), actualQoS[0]); + actualQoS[0] = -1; } connection.unsubscribe(new String[] { "TopicA" }); @@ -530,6 +526,61 @@ public class MQTTTest extends AbstractMQTTTest { } + @Test(timeout = 60 * 1000) + public void testRetainedMessage() throws Exception { + addMQTTConnector(); + brokerService.start(); + + MQTT mqtt = createMQTTConnection(); + mqtt.setKeepAlive((short) 2); + mqtt.setCleanSession(true); + + final String RETAIN = "RETAIN"; + final String TOPICA = "TopicA"; + + final String[] clientIds = { null, "foo" }; + for (String clientId : clientIds) { + + mqtt.setClientId(clientId); + final BlockingConnection connection = mqtt.blockingConnection(); + connection.connect(); + + // set retained message and check + connection.publish(TOPICA, RETAIN.getBytes(), QoS.EXACTLY_ONCE, true); + connection.subscribe(new Topic[]{new Topic(TOPICA, QoS.AT_MOST_ONCE)}); + Message msg = connection.receive(5000, TimeUnit.MILLISECONDS); + assertNotNull("No retained message for " + clientId, msg); + assertEquals(RETAIN, new String(msg.getPayload())); + msg.ack(); + + // test duplicate subscription + connection.subscribe(new Topic[]{new Topic(TOPICA, QoS.AT_MOST_ONCE)}); + msg = connection.receive(5000, TimeUnit.MILLISECONDS); + assertNotNull("No retained message on duplicate subscription for " + clientId, msg); + assertEquals(RETAIN, new String(msg.getPayload())); + msg.ack(); + connection.unsubscribe(new String[]{"TopicA"}); + + // clear retained message and check that we don't receive it + connection.publish(TOPICA, "".getBytes(), QoS.AT_MOST_ONCE, true); + connection.subscribe(new Topic[]{new Topic(TOPICA, QoS.AT_MOST_ONCE)}); + msg = connection.receive(5000, TimeUnit.MILLISECONDS); + assertNull("Retained message not cleared for " + clientId, msg); + connection.unsubscribe(new String[]{"TopicA"}); + + // set retained message again and check + connection.publish(TOPICA, RETAIN.getBytes(), QoS.EXACTLY_ONCE, true); + connection.subscribe(new Topic[]{new Topic(TOPICA, QoS.AT_MOST_ONCE)}); + msg = connection.receive(5000, TimeUnit.MILLISECONDS); + assertNotNull("No reset retained message for " + clientId, msg); + assertEquals(RETAIN, new String(msg.getPayload())); + msg.ack(); + connection.unsubscribe(new String[]{"TopicA"}); + + connection.disconnect(); + } + } + @Test(timeout = 60 * 1000) public void testFailedSubscription() throws Exception { addMQTTConnector(); @@ -966,18 +1017,31 @@ public class MQTTTest extends AbstractMQTTTest { initializeConnection(provider); final String DESTINATION_NAME = "foo.*"; + // send retained message + final String RETAINED = "RETAINED"; + provider.publish("foo/bah", RETAINED.getBytes(), AT_LEAST_ONCE, true); + ActiveMQConnection activeMQConnection = (ActiveMQConnection) new ActiveMQConnectionFactory(openwireTransport.getConnectUri()).createConnection(); + // MUST set to true to receive retained messages + activeMQConnection.setUseRetroactiveConsumer(true); activeMQConnection.start(); Session s = activeMQConnection.createSession(false, Session.AUTO_ACKNOWLEDGE); javax.jms.Topic jmsTopic = s.createTopic(DESTINATION_NAME); MessageConsumer consumer = s.createConsumer(jmsTopic); + // check whether we received retained message on JMS subscribe + ActiveMQMessage message = (ActiveMQMessage) consumer.receive(5000); + assertNotNull("Should get retained message", message); + ByteSequence bs = message.getContent(); + assertEquals(RETAINED, new String(bs.data, bs.offset, bs.length)); + assertTrue(message.getBooleanProperty(RetainedMessageSubscriptionRecoveryPolicy.RETAINED_PROPERTY)); + for (int i = 0; i < numberOfMessages; i++) { String payload = "Test Message: " + i; provider.publish("foo/bah", payload.getBytes(), AT_LEAST_ONCE); - ActiveMQMessage message = (ActiveMQMessage) consumer.receive(5000); + message = (ActiveMQMessage) consumer.receive(5000); assertNotNull("Should get a message", message); - ByteSequence bs = message.getContent(); + bs = message.getContent(); assertEquals(payload, new String(bs.data, bs.offset, bs.length)); } @@ -994,17 +1058,31 @@ public class MQTTTest extends AbstractMQTTTest { initializeConnection(provider); ActiveMQConnection activeMQConnection = (ActiveMQConnection) new ActiveMQConnectionFactory(openwireTransport.getConnectUri()).createConnection(); + activeMQConnection.setUseRetroactiveConsumer(true); activeMQConnection.start(); Session s = activeMQConnection.createSession(false, Session.AUTO_ACKNOWLEDGE); javax.jms.Topic jmsTopic = s.createTopic("foo.far"); MessageProducer producer = s.createProducer(jmsTopic); + // send retained message from JMS + final String RETAINED = "RETAINED"; + TextMessage sendMessage = s.createTextMessage(RETAINED); + // mark the message to be retained + sendMessage.setBooleanProperty(RetainedMessageSubscriptionRecoveryPolicy.RETAIN_PROPERTY, true); + // MQTT QoS can be set using MQTTProtocolConverter.QOS_PROPERTY_NAME property + sendMessage.setIntProperty(MQTTProtocolConverter.QOS_PROPERTY_NAME, 0); + producer.send(sendMessage); + provider.subscribe("foo/+", AT_MOST_ONCE); + byte[] message = provider.receive(10000); + assertNotNull("Should get retained message", message); + assertEquals(RETAINED, new String(message)); + for (int i = 0; i < numberOfMessages; i++) { String payload = "This is Test Message: " + i; - TextMessage sendMessage = s.createTextMessage(payload); + sendMessage = s.createTextMessage(payload); producer.send(sendMessage); - byte[] message = provider.receive(5000); + message = provider.receive(5000); assertNotNull("Should get a message", message); assertEquals(payload, new String(message));