Apply patch3 to fix packet id generation
This commit is contained in:
Timothy Bish 2014-03-21 17:09:52 -04:00
parent ff409b6f2c
commit afddc1a832
4 changed files with 194 additions and 44 deletions

View File

@ -0,0 +1,176 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.mqtt;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.activemq.Service;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.command.ActiveMQMessage;
import org.apache.activemq.util.LRUCache;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.util.ServiceSupport;
import org.fusesource.mqtt.codec.PUBLISH;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Manages PUBLISH packet ids for clients.
*
* @author Dhiraj Bokde
*/
public class MQTTPacketIdGenerator extends ServiceSupport {
private static final Logger LOG = LoggerFactory.getLogger(MQTTPacketIdGenerator.class);
private static final Object LOCK = new Object();
Map<String, PacketIdMaps> clientIdMap = new ConcurrentHashMap<String, PacketIdMaps>();
private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();
private MQTTPacketIdGenerator() {
}
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
synchronized (this) {
clientIdMap = new ConcurrentHashMap<String, PacketIdMaps>();
}
}
@Override
protected void doStart() throws Exception {
}
public void startClientSession(String clientId) {
if (!clientIdMap.containsKey(clientId)) {
clientIdMap.put(clientId, new PacketIdMaps());
}
}
public boolean stopClientSession(String clientId) {
return clientIdMap.remove(clientId) != null;
}
public short setPacketId(String clientId, MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
final PacketIdMaps idMaps = clientIdMap.get(clientId);
if (idMaps == null) {
// maybe its a cleansession=true client id, use session less message id
final short id = messageIdGenerator.getNextSequenceId();
publish.messageId(id);
return id;
} else {
return idMaps.setPacketId(subscription, message, publish);
}
}
public void ackPacketId(String clientId, short packetId) {
final PacketIdMaps idMaps = clientIdMap.get(clientId);
if (idMaps != null) {
idMaps.ackPacketId(packetId);
}
}
public short getNextSequenceId(String clientId) {
final PacketIdMaps idMaps = clientIdMap.get(clientId);
return idMaps != null ? idMaps.getNextSequenceId(): messageIdGenerator.getNextSequenceId();
}
public static MQTTPacketIdGenerator getMQTTPacketIdGenerator(BrokerService broker) {
MQTTPacketIdGenerator result = null;
if (broker != null) {
synchronized (LOCK) {
Service[] services = broker.getServices();
if (services != null) {
for (Service service : services) {
if (service instanceof MQTTPacketIdGenerator) {
return (MQTTPacketIdGenerator) service;
}
}
}
result = new MQTTPacketIdGenerator();
broker.addService(result);
if (broker.isStarted()) {
try {
result.start();
} catch (Exception e) {
LOG.warn("Couldn't start MQTTPacketIdGenerator");
}
}
}
}
return result;
}
private class PacketIdMaps {
private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();
final Map<String, Short> activemqToPacketIds = new LRUCache<String, Short>(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);
final Map<Short, String> packetIdsToActivemq = new LRUCache<Short, String>(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);
short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
// subscription key
final StringBuilder subscriptionKey = new StringBuilder();
subscriptionKey.append(subscription.getConsumerInfo().getDestination().getPhysicalName())
.append(':').append(message.getJMSMessageID());
final String keyStr = subscriptionKey.toString();
Short packetId;
synchronized (activemqToPacketIds) {
packetId = activemqToPacketIds.get(keyStr);
if (packetId == null) {
packetId = getNextSequenceId();
activemqToPacketIds.put(keyStr, packetId);
packetIdsToActivemq.put(packetId, keyStr);
} else {
// mark publish as duplicate!
publish.dup(true);
}
}
publish.messageId(packetId);
return packetId;
}
void ackPacketId(short packetId) {
synchronized (activemqToPacketIds) {
final String subscriptionKey = packetIdsToActivemq.remove(packetId);
if (subscriptionKey != null) {
activemqToPacketIds.remove(subscriptionKey);
}
}
}
short getNextSequenceId() {
return messageIdGenerator.getNextSequenceId();
}
}
private class NonZeroSequenceGenerator {
private short lastSequenceId;
public synchronized short getNextSequenceId() {
final short val = ++lastSequenceId;
return val != 0 ? val : ++lastSequenceId;
}
}
}

View File

@ -51,13 +51,12 @@ public class MQTTProtocolConverter {
private static final IdGenerator CONNECTION_ID_GENERATOR = new IdGenerator();
private static final MQTTFrame PING_RESP_FRAME = new PINGRESP().encode();
private static final double MQTT_KEEP_ALIVE_GRACE_PERIOD= 0.5;
private static final int DEFAULT_CACHE_SIZE = 5000;
static final int DEFAULT_CACHE_SIZE = 5000;
private static final byte SUBSCRIBE_ERROR = (byte) 0x80;
private final ConnectionId connectionId = new ConnectionId(CONNECTION_ID_GENERATOR.generateId());
private final SessionId sessionId = new SessionId(connectionId, -1);
private final ProducerId producerId = new ProducerId(sessionId, 1);
private final LongSequenceGenerator messageIdGenerator = new LongSequenceGenerator();
private final LongSequenceGenerator publisherIdGenerator = new LongSequenceGenerator();
private final LongSequenceGenerator consumerIdGenerator = new LongSequenceGenerator();
@ -68,8 +67,6 @@ public class MQTTProtocolConverter {
private final Map<Destination, UTF8Buffer> mqttTopicMap = new LRUCache<Destination, UTF8Buffer>(DEFAULT_CACHE_SIZE);
private final Map<Short, MessageAck> consumerAcks = new LRUCache<Short, MessageAck>(DEFAULT_CACHE_SIZE);
private final Map<Short, PUBREC> publisherRecs = new LRUCache<Short, PUBREC>(DEFAULT_CACHE_SIZE);
private final Map<String, Short> activemqToPacketIds = new LRUCache<String, Short>(DEFAULT_CACHE_SIZE);
private final Map<Short, String> packetIdsToActivemq = new LRUCache<Short, String>(DEFAULT_CACHE_SIZE);
private final MQTTTransport mqttTransport;
private final BrokerService brokerService;
@ -84,11 +81,13 @@ public class MQTTProtocolConverter {
private int activeMQSubscriptionPrefetch=1;
private final String QOS_PROPERTY_NAME = "QoSPropertyName";
private final MQTTRetainedMessages retainedMessages;
private final MQTTPacketIdGenerator packetIdGenerator;
public MQTTProtocolConverter(MQTTTransport mqttTransport, BrokerService brokerService) {
this.mqttTransport = mqttTransport;
this.brokerService = brokerService;
this.retainedMessages = MQTTRetainedMessages.getMQTTRetainedMessages(brokerService);
this.packetIdGenerator = MQTTPacketIdGenerator.getMQTTPacketIdGenerator(brokerService);
this.defaultKeepAlive = 0;
}
@ -276,8 +275,10 @@ public class MQTTProtocolConverter {
List<SubscriptionInfo> subs = PersistenceAdapterSupport.listSubscriptions(brokerService.getPersistenceAdapter(), connectionInfo.getClientId());
if( connect.cleanSession() ) {
packetIdGenerator.stopClientSession(getClientId());
deleteDurableSubs(subs);
} else {
packetIdGenerator.startClientSession(getClientId());
restoreDurableSubs(subs);
}
}
@ -363,7 +364,7 @@ public class MQTTProtocolConverter {
switch (retainedCopy.qos()) {
case AT_LEAST_ONCE:
case EXACTLY_ONCE:
retainedCopy.messageId(getNextSequenceId());
retainedCopy.messageId(packetIdGenerator.getNextSequenceId(getClientId()));
case AT_MOST_ONCE:
}
getMQTTTransport().sendToMQTT(retainedCopy.encode());
@ -517,7 +518,7 @@ public class MQTTProtocolConverter {
void onMQTTPubAck(PUBACK command) {
short messageId = command.messageId();
ackPacketId(messageId);
packetIdGenerator.ackPacketId(getClientId(), messageId);
MessageAck ack;
synchronized (consumerAcks) {
ack = consumerAcks.remove(messageId);
@ -549,7 +550,7 @@ public class MQTTProtocolConverter {
void onMQTTPubComp(PUBCOMP command) {
short messageId = command.messageId();
ackPacketId(messageId);
packetIdGenerator.ackPacketId(getClientId(), messageId);
MessageAck ack;
synchronized (consumerAcks) {
ack = consumerAcks.remove(messageId);
@ -662,7 +663,7 @@ public class MQTTProtocolConverter {
PUBLISH publish = new PUBLISH();
publish.topicName(connect.willTopic());
publish.qos(connect.willQos());
publish.messageId(getNextSequenceId());
publish.messageId(packetIdGenerator.getNextSequenceId(getClientId()));
publish.payload(connect.willMessage());
ActiveMQMessage message = convertMessage(publish);
message.setProducerId(producerId);
@ -739,7 +740,7 @@ public class MQTTProtocolConverter {
}
}
private String getClientId() {
String getClientId() {
if (clientId == null) {
if (connect != null && connect.clientId() != null) {
clientId = connect.clientId().toString();
@ -858,38 +859,7 @@ public class MQTTProtocolConverter {
this.activeMQSubscriptionPrefetch = activeMQSubscriptionPrefetch;
}
short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
// subscription key
final StringBuilder subscriptionKey = new StringBuilder();
subscriptionKey.append(subscription.getConsumerInfo().getDestination().getPhysicalName())
.append(':').append(message.getJMSMessageID());
final String keyStr = subscriptionKey.toString();
Short packetId;
synchronized (activemqToPacketIds) {
packetId = activemqToPacketIds.get(keyStr);
if (packetId == null) {
packetId = getNextSequenceId();
activemqToPacketIds.put(keyStr, packetId);
packetIdsToActivemq.put(packetId, keyStr);
} else {
// mark publish as duplicate!
publish.dup(true);
}
}
publish.messageId(packetId);
return packetId;
}
void ackPacketId(short packetId) {
synchronized (activemqToPacketIds) {
final String subscriptionKey = packetIdsToActivemq.remove(packetId);
if (subscriptionKey != null) {
activemqToPacketIds.remove(subscriptionKey);
}
}
}
short getNextSequenceId() {
return (short) messageIdGenerator.getNextSequenceId();
public MQTTPacketIdGenerator getPacketIdGenerator() {
return packetIdGenerator;
}
}

View File

@ -57,7 +57,7 @@ class MQTTSubscription {
case AT_LEAST_ONCE:
case EXACTLY_ONCE:
// set packet id, and optionally dup flag
protocolConverter.setPacketId(this, message, publish);
protocolConverter.getPacketIdGenerator().setPacketId(protocolConverter.getClientId(), this, message, publish);
case AT_MOST_ONCE:
}
return publish;

View File

@ -20,6 +20,9 @@ import java.net.ProtocolException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
@ -662,10 +665,11 @@ public class MQTTTest extends AbstractMQTTTest {
@Test(timeout = 60 * 1000)
public void testResendMessageId() throws Exception {
addMQTTConnector();
addMQTTConnector("trace=true");
brokerService.start();
final MQTT mqtt = createMQTTConnection("resend", false);
mqtt.setKeepAlive((short) 5);
final List<PUBLISH> publishList = new ArrayList<PUBLISH>();
mqtt.setTracer(new Tracer() {