diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTPacketIdGenerator.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTPacketIdGenerator.java new file mode 100644 index 0000000000..bf57f1c9f4 --- /dev/null +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTPacketIdGenerator.java @@ -0,0 +1,176 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.activemq.transport.mqtt; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.activemq.Service; +import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.command.ActiveMQMessage; +import org.apache.activemq.util.LRUCache; +import org.apache.activemq.util.ServiceStopper; +import org.apache.activemq.util.ServiceSupport; +import org.fusesource.mqtt.codec.PUBLISH; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages PUBLISH packet ids for clients. + * + * @author Dhiraj Bokde + */ +public class MQTTPacketIdGenerator extends ServiceSupport { + + private static final Logger LOG = LoggerFactory.getLogger(MQTTPacketIdGenerator.class); + private static final Object LOCK = new Object(); + + Map clientIdMap = new ConcurrentHashMap(); + + private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator(); + + private MQTTPacketIdGenerator() { + } + + @Override + protected void doStop(ServiceStopper stopper) throws Exception { + synchronized (this) { + clientIdMap = new ConcurrentHashMap(); + } + } + + @Override + protected void doStart() throws Exception { + } + + public void startClientSession(String clientId) { + if (!clientIdMap.containsKey(clientId)) { + clientIdMap.put(clientId, new PacketIdMaps()); + } + } + + public boolean stopClientSession(String clientId) { + return clientIdMap.remove(clientId) != null; + } + + public short setPacketId(String clientId, MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) { + final PacketIdMaps idMaps = clientIdMap.get(clientId); + if (idMaps == null) { + // maybe its a cleansession=true client id, use session less message id + final short id = messageIdGenerator.getNextSequenceId(); + publish.messageId(id); + return id; + } else { + return idMaps.setPacketId(subscription, message, publish); + } + } + + public void ackPacketId(String clientId, short packetId) { + final PacketIdMaps idMaps = clientIdMap.get(clientId); + if (idMaps != null) { + idMaps.ackPacketId(packetId); + } + } + + public short getNextSequenceId(String clientId) { + final PacketIdMaps idMaps = clientIdMap.get(clientId); + return idMaps != null ? idMaps.getNextSequenceId(): messageIdGenerator.getNextSequenceId(); + } + + public static MQTTPacketIdGenerator getMQTTPacketIdGenerator(BrokerService broker) { + MQTTPacketIdGenerator result = null; + if (broker != null) { + synchronized (LOCK) { + Service[] services = broker.getServices(); + if (services != null) { + for (Service service : services) { + if (service instanceof MQTTPacketIdGenerator) { + return (MQTTPacketIdGenerator) service; + } + } + } + result = new MQTTPacketIdGenerator(); + broker.addService(result); + if (broker.isStarted()) { + try { + result.start(); + } catch (Exception e) { + LOG.warn("Couldn't start MQTTPacketIdGenerator"); + } + } + } + } + + + return result; + } + + private class PacketIdMaps { + + private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator(); + final Map activemqToPacketIds = new LRUCache(MQTTProtocolConverter.DEFAULT_CACHE_SIZE); + final Map packetIdsToActivemq = new LRUCache(MQTTProtocolConverter.DEFAULT_CACHE_SIZE); + + short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) { + // subscription key + final StringBuilder subscriptionKey = new StringBuilder(); + subscriptionKey.append(subscription.getConsumerInfo().getDestination().getPhysicalName()) + .append(':').append(message.getJMSMessageID()); + final String keyStr = subscriptionKey.toString(); + Short packetId; + synchronized (activemqToPacketIds) { + packetId = activemqToPacketIds.get(keyStr); + if (packetId == null) { + packetId = getNextSequenceId(); + activemqToPacketIds.put(keyStr, packetId); + packetIdsToActivemq.put(packetId, keyStr); + } else { + // mark publish as duplicate! + publish.dup(true); + } + } + publish.messageId(packetId); + return packetId; + } + + void ackPacketId(short packetId) { + synchronized (activemqToPacketIds) { + final String subscriptionKey = packetIdsToActivemq.remove(packetId); + if (subscriptionKey != null) { + activemqToPacketIds.remove(subscriptionKey); + } + } + } + + short getNextSequenceId() { + return messageIdGenerator.getNextSequenceId(); + } + + } + + private class NonZeroSequenceGenerator { + + private short lastSequenceId; + + public synchronized short getNextSequenceId() { + final short val = ++lastSequenceId; + return val != 0 ? val : ++lastSequenceId; + } + + } + +} diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java index 014c6f6cb9..0e590f0ad4 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java @@ -51,13 +51,12 @@ public class MQTTProtocolConverter { private static final IdGenerator CONNECTION_ID_GENERATOR = new IdGenerator(); private static final MQTTFrame PING_RESP_FRAME = new PINGRESP().encode(); private static final double MQTT_KEEP_ALIVE_GRACE_PERIOD= 0.5; - private static final int DEFAULT_CACHE_SIZE = 5000; + static final int DEFAULT_CACHE_SIZE = 5000; private static final byte SUBSCRIBE_ERROR = (byte) 0x80; private final ConnectionId connectionId = new ConnectionId(CONNECTION_ID_GENERATOR.generateId()); private final SessionId sessionId = new SessionId(connectionId, -1); private final ProducerId producerId = new ProducerId(sessionId, 1); - private final LongSequenceGenerator messageIdGenerator = new LongSequenceGenerator(); private final LongSequenceGenerator publisherIdGenerator = new LongSequenceGenerator(); private final LongSequenceGenerator consumerIdGenerator = new LongSequenceGenerator(); @@ -68,8 +67,6 @@ public class MQTTProtocolConverter { private final Map mqttTopicMap = new LRUCache(DEFAULT_CACHE_SIZE); private final Map consumerAcks = new LRUCache(DEFAULT_CACHE_SIZE); private final Map publisherRecs = new LRUCache(DEFAULT_CACHE_SIZE); - private final Map activemqToPacketIds = new LRUCache(DEFAULT_CACHE_SIZE); - private final Map packetIdsToActivemq = new LRUCache(DEFAULT_CACHE_SIZE); private final MQTTTransport mqttTransport; private final BrokerService brokerService; @@ -84,11 +81,13 @@ public class MQTTProtocolConverter { private int activeMQSubscriptionPrefetch=1; private final String QOS_PROPERTY_NAME = "QoSPropertyName"; private final MQTTRetainedMessages retainedMessages; + private final MQTTPacketIdGenerator packetIdGenerator; public MQTTProtocolConverter(MQTTTransport mqttTransport, BrokerService brokerService) { this.mqttTransport = mqttTransport; this.brokerService = brokerService; this.retainedMessages = MQTTRetainedMessages.getMQTTRetainedMessages(brokerService); + this.packetIdGenerator = MQTTPacketIdGenerator.getMQTTPacketIdGenerator(brokerService); this.defaultKeepAlive = 0; } @@ -276,8 +275,10 @@ public class MQTTProtocolConverter { List subs = PersistenceAdapterSupport.listSubscriptions(brokerService.getPersistenceAdapter(), connectionInfo.getClientId()); if( connect.cleanSession() ) { + packetIdGenerator.stopClientSession(getClientId()); deleteDurableSubs(subs); } else { + packetIdGenerator.startClientSession(getClientId()); restoreDurableSubs(subs); } } @@ -363,7 +364,7 @@ public class MQTTProtocolConverter { switch (retainedCopy.qos()) { case AT_LEAST_ONCE: case EXACTLY_ONCE: - retainedCopy.messageId(getNextSequenceId()); + retainedCopy.messageId(packetIdGenerator.getNextSequenceId(getClientId())); case AT_MOST_ONCE: } getMQTTTransport().sendToMQTT(retainedCopy.encode()); @@ -517,7 +518,7 @@ public class MQTTProtocolConverter { void onMQTTPubAck(PUBACK command) { short messageId = command.messageId(); - ackPacketId(messageId); + packetIdGenerator.ackPacketId(getClientId(), messageId); MessageAck ack; synchronized (consumerAcks) { ack = consumerAcks.remove(messageId); @@ -549,7 +550,7 @@ public class MQTTProtocolConverter { void onMQTTPubComp(PUBCOMP command) { short messageId = command.messageId(); - ackPacketId(messageId); + packetIdGenerator.ackPacketId(getClientId(), messageId); MessageAck ack; synchronized (consumerAcks) { ack = consumerAcks.remove(messageId); @@ -662,7 +663,7 @@ public class MQTTProtocolConverter { PUBLISH publish = new PUBLISH(); publish.topicName(connect.willTopic()); publish.qos(connect.willQos()); - publish.messageId(getNextSequenceId()); + publish.messageId(packetIdGenerator.getNextSequenceId(getClientId())); publish.payload(connect.willMessage()); ActiveMQMessage message = convertMessage(publish); message.setProducerId(producerId); @@ -739,7 +740,7 @@ public class MQTTProtocolConverter { } } - private String getClientId() { + String getClientId() { if (clientId == null) { if (connect != null && connect.clientId() != null) { clientId = connect.clientId().toString(); @@ -858,38 +859,7 @@ public class MQTTProtocolConverter { this.activeMQSubscriptionPrefetch = activeMQSubscriptionPrefetch; } - short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) { - // subscription key - final StringBuilder subscriptionKey = new StringBuilder(); - subscriptionKey.append(subscription.getConsumerInfo().getDestination().getPhysicalName()) - .append(':').append(message.getJMSMessageID()); - final String keyStr = subscriptionKey.toString(); - Short packetId; - synchronized (activemqToPacketIds) { - packetId = activemqToPacketIds.get(keyStr); - if (packetId == null) { - packetId = getNextSequenceId(); - activemqToPacketIds.put(keyStr, packetId); - packetIdsToActivemq.put(packetId, keyStr); - } else { - // mark publish as duplicate! - publish.dup(true); - } - } - publish.messageId(packetId); - return packetId; - } - - void ackPacketId(short packetId) { - synchronized (activemqToPacketIds) { - final String subscriptionKey = packetIdsToActivemq.remove(packetId); - if (subscriptionKey != null) { - activemqToPacketIds.remove(subscriptionKey); - } - } - } - - short getNextSequenceId() { - return (short) messageIdGenerator.getNextSequenceId(); + public MQTTPacketIdGenerator getPacketIdGenerator() { + return packetIdGenerator; } } diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSubscription.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSubscription.java index 0eed8f648a..b4971bcd80 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSubscription.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSubscription.java @@ -57,7 +57,7 @@ class MQTTSubscription { case AT_LEAST_ONCE: case EXACTLY_ONCE: // set packet id, and optionally dup flag - protocolConverter.setPacketId(this, message, publish); + protocolConverter.getPacketIdGenerator().setPacketId(protocolConverter.getClientId(), this, message, publish); case AT_MOST_ONCE: } return publish; diff --git a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java index 3c87b340c8..d8788d362c 100644 --- a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java +++ b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java @@ -20,6 +20,9 @@ import java.net.ProtocolException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; @@ -662,10 +665,11 @@ public class MQTTTest extends AbstractMQTTTest { @Test(timeout = 60 * 1000) public void testResendMessageId() throws Exception { - addMQTTConnector(); + addMQTTConnector("trace=true"); brokerService.start(); final MQTT mqtt = createMQTTConnection("resend", false); + mqtt.setKeepAlive((short) 5); final List publishList = new ArrayList(); mqtt.setTracer(new Tracer() {