ARTEMIS-1218 implement MQTT link stealing

This commit is contained in:
Justin Bertram 2017-09-15 10:59:57 -05:00 committed by Clebert Suconic
parent 144dbadcb5
commit dac625179a
3 changed files with 66 additions and 54 deletions

View File

@ -143,8 +143,12 @@ public class MQTTConnectionManager {
if (session.getSessionState() != null) { if (session.getSessionState() != null) {
session.getSessionState().setAttached(false); session.getSessionState().setAttached(false);
String clientId = session.getSessionState().getClientId(); String clientId = session.getSessionState().getClientId();
if (clientId != null) { /**
session.getProtocolManager().getConnectedClients().remove(clientId); * ensure that the connection for the client ID matches *this* connection otherwise we could remove the
* entry for the client who "stole" this client ID via [MQTT-3.1.4-2]
*/
if (clientId != null && session.getProtocolManager().isClientConnected(clientId, session.getConnection())) {
session.getProtocolManager().removeConnectedClient(clientId);
} }
} }
} }
@ -176,12 +180,13 @@ public class MQTTConnectionManager {
// [MQTT-3.1.3-8] Return ID rejected and disconnect if clean session = false and client id is null // [MQTT-3.1.3-8] Return ID rejected and disconnect if clean session = false and client id is null
return null; return null;
} }
} else if (!session.getProtocolManager().getConnectedClients().add(clientId)) { } else {
// ^^^ If the client ID is not unique (i.e. it has already registered) then do not accept it. MQTTConnection connection = session.getProtocolManager().addConnectedClient(clientId, session.getConnection());
if (connection != null) {
// [MQTT-3.1.3-9] Return ID Rejected if server rejects the client ID // [MQTT-3.1.4-2] If the client ID represents a client already connected to the server then the server MUST disconnect the existing client
return null; connection.disconnect(false);
}
} }
return clientId; return clientId;
} }

View File

@ -19,7 +19,8 @@ package org.apache.activemq.artemis.core.protocol.mqtt;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -39,7 +40,6 @@ import org.apache.activemq.artemis.spi.core.protocol.ProtocolManagerFactory;
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection; import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection;
import org.apache.activemq.artemis.spi.core.remoting.Acceptor; import org.apache.activemq.artemis.spi.core.remoting.Acceptor;
import org.apache.activemq.artemis.spi.core.remoting.Connection; import org.apache.activemq.artemis.spi.core.remoting.Connection;
import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet;
/** /**
* MQTTProtocolManager * MQTTProtocolManager
@ -55,7 +55,7 @@ class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQTTInter
private final List<MQTTInterceptor> outgoingInterceptors = new ArrayList<>(); private final List<MQTTInterceptor> outgoingInterceptors = new ArrayList<>();
//TODO Read in a list of existing client IDs from stored Sessions. //TODO Read in a list of existing client IDs from stored Sessions.
private Set<String> connectedClients = new ConcurrentHashSet<>(); private Map<String, MQTTConnection> connectedClients = new ConcurrentHashMap<>();
MQTTProtocolManager(ActiveMQServer server, MQTTProtocolManager(ActiveMQServer server,
List<BaseInterceptor> incomingInterceptors, List<BaseInterceptor> incomingInterceptors,
@ -178,7 +178,21 @@ class MQTTProtocolManager extends AbstractProtocolManager<MqttMessage, MQTTInter
super.invokeInterceptors(this.outgoingInterceptors, mqttMessage, connection); super.invokeInterceptors(this.outgoingInterceptors, mqttMessage, connection);
} }
public Set<String> getConnectedClients() { public boolean isClientConnected(String clientId, MQTTConnection connection) {
return connectedClients; return connectedClients.get(clientId).equals(connection);
}
public void removeConnectedClient(String clientId) {
connectedClients.remove(clientId);
}
/**
* @param clientId
* @param connection
* @return the {@code MQTTConnection} that the added connection replaced or null if there was no previous entry for
* the {@code clientId}
*/
public MQTTConnection addConnectedClient(String clientId, MQTTConnection connection) {
return connectedClients.put(clientId, connection);
} }
} }

View File

@ -36,6 +36,7 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.api.core.RoutingType;
@ -56,7 +57,6 @@ import org.apache.activemq.transport.amqp.client.AmqpSender;
import org.apache.activemq.transport.amqp.client.AmqpSession; import org.apache.activemq.transport.amqp.client.AmqpSession;
import org.fusesource.mqtt.client.BlockingConnection; import org.fusesource.mqtt.client.BlockingConnection;
import org.fusesource.mqtt.client.MQTT; import org.fusesource.mqtt.client.MQTT;
import org.fusesource.mqtt.client.MQTTException;
import org.fusesource.mqtt.client.Message; import org.fusesource.mqtt.client.Message;
import org.fusesource.mqtt.client.QoS; import org.fusesource.mqtt.client.QoS;
import org.fusesource.mqtt.client.Topic; import org.fusesource.mqtt.client.Topic;
@ -1350,11 +1350,8 @@ public class MQTTTest extends MQTTTestSupport {
connection.disconnect(); connection.disconnect();
} }
@Ignore
@Test(timeout = 60 * 1000) @Test(timeout = 60 * 1000)
// TODO We currently do not support link stealing. This needs to be enabled for this test to pass.
public void testDuplicateClientId() throws Exception { public void testDuplicateClientId() throws Exception {
// test link stealing enabled by default
final String clientId = "duplicateClient"; final String clientId = "duplicateClient";
MQTT mqtt = createMQTTConnection(clientId, false); MQTT mqtt = createMQTTConnection(clientId, false);
mqtt.setKeepAlive((short) 2); mqtt.setKeepAlive((short) 2);
@ -1384,31 +1381,45 @@ public class MQTTTest extends MQTTTestSupport {
connection1.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); connection1.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true);
connection1.disconnect(); connection1.disconnect();
}
// disable link stealing @Test(timeout = 60 * 1000)
stopBroker(); public void testRepeatedLinkStealing() throws Exception {
protocolConfig = "allowLinkStealing=false"; final String clientId = "duplicateClient";
startBroker(); final AtomicReference<BlockingConnection> oldConnection = new AtomicReference<>();
final String TOPICA = "TopicA";
mqtt = createMQTTConnection(clientId, false); for (int i = 1; i <= 10; ++i) {
mqtt.setKeepAlive((short) 2);
final BlockingConnection connection2 = mqtt.blockingConnection();
connection2.connect();
connection2.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true);
mqtt1 = createMQTTConnection(clientId, false); LOG.info("Creating MQTT Connection {}", i);
mqtt1.setKeepAlive((short) 2);
final BlockingConnection connection3 = mqtt1.blockingConnection(); MQTT mqtt = createMQTTConnection(clientId, false);
try { mqtt.setKeepAlive((short) 2);
connection3.connect(); final BlockingConnection connection = mqtt.blockingConnection();
fail("Duplicate client connected"); connection.connect();
} catch (Exception e) { connection.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true);
// ignore
assertTrue("Client connect failed for attempt: " + i, Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisfied() throws Exception {
return connection.isConnected();
}
}, TimeUnit.SECONDS.toMillis(3), TimeUnit.MILLISECONDS.toMillis(200)));
if (oldConnection.get() != null) {
assertTrue("Old client still connected on attempt: " + i, Wait.waitFor(new Wait.Condition() {
@Override
public boolean isSatisfied() throws Exception {
return !oldConnection.get().isConnected();
}
}, TimeUnit.SECONDS.toMillis(3), TimeUnit.MILLISECONDS.toMillis(200)));
}
oldConnection.set(connection);
} }
assertTrue("Old client disconnected", connection2.isConnected()); oldConnection.get().publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true);
connection2.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); oldConnection.get().disconnect();
connection2.disconnect();
} }
@Test(timeout = 30 * 10000) @Test(timeout = 30 * 10000)
@ -1968,24 +1979,6 @@ public class MQTTTest extends MQTTTestSupport {
} }
@Test
public void testDuplicateIDReturnsError() throws Exception {
String clientId = "clientId";
MQTT mqtt = createMQTTConnection();
mqtt.setClientId(clientId);
mqtt.blockingConnection().connect();
MQTTException e = null;
try {
MQTT mqtt2 = createMQTTConnection();
mqtt2.setClientId(clientId);
mqtt2.blockingConnection().connect();
} catch (MQTTException mqttE) {
e = mqttE;
}
assertTrue(e.getMessage().contains("CONNECTION_REFUSED_IDENTIFIER_REJECTED"));
}
@Test @Test
public void testDoubleBroker() throws Exception { public void testDoubleBroker() throws Exception {
/* /*