diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java index 02e1c66965..7e88028550 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTConnectionManager.java @@ -17,7 +17,6 @@ package org.apache.activemq.artemis.core.protocol.mqtt; -import java.util.Set; import java.util.UUID; import io.netty.buffer.ByteBuf; @@ -29,7 +28,6 @@ import org.apache.activemq.artemis.core.server.ActiveMQServer; import org.apache.activemq.artemis.core.server.ServerSession; import org.apache.activemq.artemis.core.server.impl.ServerSessionImpl; import org.apache.activemq.artemis.utils.UUIDGenerator; -import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet; /** * MQTTConnectionManager is responsible for handle Connect and Disconnect packets and any resulting behaviour of these @@ -39,9 +37,6 @@ public class MQTTConnectionManager { private MQTTSession session; - //TODO Read in a list of existing client IDs from stored Sessions. - public static Set CONNECTED_CLIENTS = new ConcurrentHashSet<>(); - private MQTTLogger log = MQTTLogger.LOGGER; private boolean isWill = false; @@ -148,8 +143,12 @@ public class MQTTConnectionManager { if (session.getSessionState() != null) { session.getSessionState().setAttached(false); String clientId = session.getSessionState().getClientId(); - if (clientId != null) { - CONNECTED_CLIENTS.remove(clientId); + /** + * ensure that the connection for the client ID matches *this* connection otherwise we could remove the + * entry for the client who "stole" this client ID via [MQTT-3.1.4-2] + */ + if (clientId != null && session.getProtocolManager().isClientConnected(clientId, session.getConnection())) { + session.getProtocolManager().removeConnectedClient(clientId); } } } @@ -181,12 +180,13 @@ public class MQTTConnectionManager { // [MQTT-3.1.3-8] Return ID rejected and disconnect if clean session = false and client id is null return null; } - } else if (!CONNECTED_CLIENTS.add(clientId)) { - // ^^^ If the client ID is not unique (i.e. it has already registered) then do not accept it. + } else { + MQTTConnection connection = session.getProtocolManager().addConnectedClient(clientId, session.getConnection()); - - // [MQTT-3.1.3-9] Return ID Rejected if server rejects the client ID - return null; + if (connection != null) { + // [MQTT-3.1.4-2] If the client ID represents a client already connected to the server then the server MUST disconnect the existing client + connection.disconnect(false); + } } return clientId; } diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java index 6118b0dfed..c8832bafaa 100644 --- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java +++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTProtocolManager.java @@ -19,6 +19,8 @@ package org.apache.activemq.artemis.core.protocol.mqtt; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -52,6 +54,9 @@ class MQTTProtocolManager extends AbstractProtocolManager incomingInterceptors = new ArrayList<>(); private final List outgoingInterceptors = new ArrayList<>(); + //TODO Read in a list of existing client IDs from stored Sessions. + private Map connectedClients = new ConcurrentHashMap<>(); + MQTTProtocolManager(ActiveMQServer server, List incomingInterceptors, List outgoingInterceptors) { @@ -172,4 +177,22 @@ class MQTTProtocolManager extends AbstractProtocolManager()); - - Field connectedClients = MQTTConnectionManager.class.getDeclaredField("CONNECTED_CLIENTS"); - connectedClients.setAccessible(true); - connectedClients.set(null, new ConcurrentHashSet<>()); super.setUp(); } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTInterceptorPropertiesTest.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTInterceptorPropertiesTest.java index 2600952944..c95a462cb2 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTInterceptorPropertiesTest.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTInterceptorPropertiesTest.java @@ -16,27 +16,25 @@ */ package org.apache.activemq.artemis.tests.integration.mqtt.imported; -import io.netty.handler.codec.mqtt.MqttFixedHeader; -import io.netty.handler.codec.mqtt.MqttMessage; -import io.netty.handler.codec.mqtt.MqttPublishMessage; -import org.apache.activemq.artemis.api.core.ActiveMQException; -import org.apache.activemq.artemis.core.protocol.mqtt.MQTTConnectionManager; -import org.apache.activemq.artemis.core.protocol.mqtt.MQTTInterceptor; -import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSession; -import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection; -import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet; -import org.apache.felix.resolver.util.ArrayMap; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ErrorCollector; - import java.lang.reflect.Field; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import io.netty.handler.codec.mqtt.MqttFixedHeader; +import io.netty.handler.codec.mqtt.MqttMessage; +import io.netty.handler.codec.mqtt.MqttPublishMessage; +import org.apache.activemq.artemis.api.core.ActiveMQException; +import org.apache.activemq.artemis.core.protocol.mqtt.MQTTInterceptor; +import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSession; +import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection; +import org.apache.felix.resolver.util.ArrayMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + public class MQTTInterceptorPropertiesTest extends MQTTTestSupport { @Override @@ -45,10 +43,6 @@ public class MQTTInterceptorPropertiesTest extends MQTTTestSupport { Field sessions = MQTTSession.class.getDeclaredField("SESSIONS"); sessions.setAccessible(true); sessions.set(null, new ConcurrentHashMap<>()); - - Field connectedClients = MQTTConnectionManager.class.getDeclaredField("CONNECTED_CLIENTS"); - connectedClients.setAccessible(true); - connectedClients.set(null, new ConcurrentHashSet<>()); super.setUp(); } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java index e3c4856039..93383840e0 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/mqtt/imported/MQTTTest.java @@ -36,19 +36,20 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Pattern; import org.apache.activemq.artemis.api.core.RoutingType; import org.apache.activemq.artemis.api.core.SimpleString; +import org.apache.activemq.artemis.core.config.Configuration; import org.apache.activemq.artemis.core.config.CoreAddressConfiguration; import org.apache.activemq.artemis.core.config.CoreQueueConfiguration; -import org.apache.activemq.artemis.core.protocol.mqtt.MQTTConnectionManager; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSession; import org.apache.activemq.artemis.core.protocol.mqtt.MQTTUtil; +import org.apache.activemq.artemis.core.server.ActiveMQServer; import org.apache.activemq.artemis.core.server.Queue; import org.apache.activemq.artemis.core.server.impl.AddressInfo; import org.apache.activemq.artemis.tests.util.Wait; -import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet; import org.apache.activemq.transport.amqp.client.AmqpClient; import org.apache.activemq.transport.amqp.client.AmqpConnection; import org.apache.activemq.transport.amqp.client.AmqpMessage; @@ -56,7 +57,6 @@ import org.apache.activemq.transport.amqp.client.AmqpSender; import org.apache.activemq.transport.amqp.client.AmqpSession; import org.fusesource.mqtt.client.BlockingConnection; import org.fusesource.mqtt.client.MQTT; -import org.fusesource.mqtt.client.MQTTException; import org.fusesource.mqtt.client.Message; import org.fusesource.mqtt.client.QoS; import org.fusesource.mqtt.client.Topic; @@ -85,12 +85,7 @@ public class MQTTTest extends MQTTTestSupport { Field sessions = MQTTSession.class.getDeclaredField("SESSIONS"); sessions.setAccessible(true); sessions.set(null, new ConcurrentHashMap<>()); - - Field connectedClients = MQTTConnectionManager.class.getDeclaredField("CONNECTED_CLIENTS"); - connectedClients.setAccessible(true); - connectedClients.set(null, new ConcurrentHashSet<>()); super.setUp(); - } @Test @@ -847,26 +842,14 @@ public class MQTTTest extends MQTTTestSupport { // publish non-retained message connection.publish(TOPIC, TOPIC.getBytes(), QoS.EXACTLY_ONCE, false); - Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return publishList.size() == 2; - } - }, 5000); - assertEquals(2, publishList.size()); + assertTrue(Wait.waitFor(() -> publishList.size() == 2, 5000)); connection.disconnect(); connection = mqtt.blockingConnection(); connection.connect(); - Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return publishList.size() == 4; - } - }, 5000); - assertEquals(4, publishList.size()); + assertTrue(Wait.waitFor(() -> publishList.size() == 4, 5000)); // TODO Investigate if receiving the same ID for overlapping subscriptions is actually spec compliant. // In Artemis we send a new ID for every copy of the message. @@ -1023,12 +1006,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - }); + Wait.waitFor(() -> connection.isConnected()); final String TOPIC = "TopicA"; final byte[] qos = connection.subscribe(new Topic[]{new Topic(TOPIC, QoS.EXACTLY_ONCE)}); @@ -1042,12 +1020,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection newConnection = mqtt.blockingConnection(); newConnection.connect(); - Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return newConnection.isConnected(); - } - }); + Wait.waitFor(() -> newConnection.isConnected()); assertEquals(QoS.EXACTLY_ONCE.ordinal(), qos[0]); Message msg = newConnection.receive(1000, TimeUnit.MILLISECONDS); @@ -1069,12 +1042,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - }); + Wait.waitFor(() -> connection.isConnected()); MQTT mqtt2 = createMQTTConnection("2", false); BlockingConnection connection2 = mqtt2.blockingConnection(); @@ -1103,12 +1071,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - }); + Wait.waitFor(() -> connection.isConnected()); // kill transport connection.kill(); @@ -1281,13 +1244,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - assertTrue("KeepAlive didn't work properly", Wait.waitFor(new Wait.Condition() { - - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - })); + assertTrue("KeepAlive didn't work properly", Wait.waitFor(() -> connection.isConnected())); connection.disconnect(); } @@ -1304,13 +1261,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - assertTrue("KeepAlive didn't work properly", Wait.waitFor(new Wait.Condition() { - - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - })); + assertTrue("KeepAlive didn't work properly", Wait.waitFor(() -> connection.isConnected())); connection.disconnect(); } @@ -1355,11 +1306,8 @@ public class MQTTTest extends MQTTTestSupport { connection.disconnect(); } - @Ignore @Test(timeout = 60 * 1000) - // TODO We currently do not support link stealing. This needs to be enabled for this test to pass. public void testDuplicateClientId() throws Exception { - // test link stealing enabled by default final String clientId = "duplicateClient"; MQTT mqtt = createMQTTConnection(clientId, false); mqtt.setKeepAlive((short) 2); @@ -1373,47 +1321,41 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection1 = mqtt1.blockingConnection(); connection1.connect(); - assertTrue("Duplicate client disconnected", Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return connection1.isConnected(); - } - })); + assertTrue("Duplicate client disconnected", Wait.waitFor(() -> connection1.isConnected())); - assertTrue("Old client still connected", Wait.waitFor(new Wait.Condition() { - @Override - public boolean isSatisfied() throws Exception { - return !connection.isConnected(); - } - })); + assertTrue("Old client still connected", Wait.waitFor(() -> !connection.isConnected())); connection1.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); connection1.disconnect(); + } - // disable link stealing - stopBroker(); - protocolConfig = "allowLinkStealing=false"; - startBroker(); + @Test(timeout = 60 * 1000) + public void testRepeatedLinkStealing() throws Exception { + final String clientId = "duplicateClient"; + final AtomicReference oldConnection = new AtomicReference<>(); + final String TOPICA = "TopicA"; - mqtt = createMQTTConnection(clientId, false); - mqtt.setKeepAlive((short) 2); - final BlockingConnection connection2 = mqtt.blockingConnection(); - connection2.connect(); - connection2.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + for (int i = 1; i <= 10; ++i) { - mqtt1 = createMQTTConnection(clientId, false); - mqtt1.setKeepAlive((short) 2); - final BlockingConnection connection3 = mqtt1.blockingConnection(); - try { - connection3.connect(); - fail("Duplicate client connected"); - } catch (Exception e) { - // ignore + LOG.info("Creating MQTT Connection {}", i); + + MQTT mqtt = createMQTTConnection(clientId, false); + mqtt.setKeepAlive((short) 2); + final BlockingConnection connection = mqtt.blockingConnection(); + connection.connect(); + connection.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + + assertTrue("Client connect failed for attempt: " + i, Wait.waitFor(() -> connection.isConnected(), 3000, 200)); + + if (oldConnection.get() != null) { + assertTrue("Old client still connected on attempt: " + i, Wait.waitFor(() -> !oldConnection.get().isConnected(), 3000, 200)); + } + + oldConnection.set(connection); } - assertTrue("Old client disconnected", connection2.isConnected()); - connection2.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); - connection2.disconnect(); + oldConnection.get().publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + oldConnection.get().disconnect(); } @Test(timeout = 30 * 10000) @@ -1569,13 +1511,7 @@ public class MQTTTest extends MQTTTestSupport { final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - assertTrue("KeepAlive didn't work properly", Wait.waitFor(new Wait.Condition() { - - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - })); + assertTrue("KeepAlive didn't work properly", Wait.waitFor(() -> connection.isConnected())); } @Test(timeout = 60 * 1000) @@ -1767,13 +1703,7 @@ public class MQTTTest extends MQTTTestSupport { mqtt.setKeepAlive((short) 2); final BlockingConnection connection = mqtt.blockingConnection(); connection.connect(); - assertTrue("KeepAlive didn't work properly", Wait.waitFor(new Wait.Condition() { - - @Override - public boolean isSatisfied() throws Exception { - return connection.isConnected(); - } - })); + assertTrue("KeepAlive didn't work properly", Wait.waitFor(() -> connection.isConnected())); connection.disconnect(); } @@ -1974,20 +1904,47 @@ public class MQTTTest extends MQTTTestSupport { } @Test - public void testDuplicateIDReturnsError() throws Exception { - String clientId = "clientId"; - MQTT mqtt = createMQTTConnection(); - mqtt.setClientId(clientId); - mqtt.blockingConnection().connect(); + public void testDoubleBroker() throws Exception { + /* + * Start two embedded server instances for MQTT and connect to them + * with the same MQTT client id. As those are two different instances + * connecting to them with the same client ID must succeed. + */ + + final int port1 = 1884; + final int port2 = 1885; + + final Configuration cfg1 = createDefaultConfig(1, false); + cfg1.addAcceptorConfiguration("mqtt1", "tcp://localhost:" + port1 + "?protocols=MQTT"); + + final Configuration cfg2 = createDefaultConfig(2, false); + cfg2.addAcceptorConfiguration("mqtt2", "tcp://localhost:" + port2 + "?protocols=MQTT"); + + final ActiveMQServer server1 = createServer(cfg1); + server1.start(); + final ActiveMQServer server2 = createServer(cfg2); + server2.start(); + + final String clientId = "client1"; + final MQTT mqtt1 = createMQTTConnection(clientId, true); + final MQTT mqtt2 = createMQTTConnection(clientId, true); + + mqtt1.setHost("localhost", port1); + mqtt2.setHost("localhost", port2); + + final BlockingConnection connection1 = mqtt1.blockingConnection(); + final BlockingConnection connection2 = mqtt2.blockingConnection(); - MQTTException e = null; try { - MQTT mqtt2 = createMQTTConnection(); - mqtt2.setClientId(clientId); - mqtt2.blockingConnection().connect(); - } catch (MQTTException mqttE) { - e = mqttE; + connection1.connect(); + connection2.connect(); + } catch (Exception e) { + fail("Connections should have worked."); + } finally { + if (connection1.isConnected()) + connection1.disconnect(); + if (connection2.isConnected()) + connection2.disconnect(); } - assertTrue(e.getMessage().contains("CONNECTION_REFUSED_IDENTIFIER_REJECTED")); } } diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/plugin/MqttPluginTest.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/plugin/MqttPluginTest.java index 660df34f3e..2365ae5c43 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/plugin/MqttPluginTest.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/integration/plugin/MqttPluginTest.java @@ -16,20 +16,14 @@ */ package org.apache.activemq.artemis.tests.integration.plugin; -import java.lang.reflect.Field; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.activemq.artemis.core.protocol.mqtt.MQTTConnectionManager; -import org.apache.activemq.artemis.core.protocol.mqtt.MQTTSession; import org.apache.activemq.artemis.tests.integration.mqtt.imported.MQTTClientProvider; import org.apache.activemq.artemis.tests.integration.mqtt.imported.MQTTTestSupport; -import org.apache.activemq.artemis.utils.collections.ConcurrentHashSet; -import org.junit.Before; import org.junit.Test; import static org.apache.activemq.artemis.tests.integration.plugin.MethodCalledVerifier.AFTER_CLOSE_CONSUMER; @@ -61,20 +55,6 @@ public class MqttPluginTest extends MQTTTestSupport { private final Map methodCalls = new HashMap<>(); private final MethodCalledVerifier verifier = new MethodCalledVerifier(methodCalls); - @Override - @Before - public void setUp() throws Exception { - Field sessions = MQTTSession.class.getDeclaredField("SESSIONS"); - sessions.setAccessible(true); - sessions.set(null, new ConcurrentHashMap<>()); - - Field connectedClients = MQTTConnectionManager.class.getDeclaredField("CONNECTED_CLIENTS"); - connectedClients.setAccessible(true); - connectedClients.set(null, new ConcurrentHashSet<>()); - super.setUp(); - - } - @Override public void configureBroker() throws Exception { super.configureBroker(); diff --git a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/util/Wait.java b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/util/Wait.java index 795a47844b..2f3772af82 100644 --- a/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/util/Wait.java +++ b/tests/integration-tests/src/test/java/org/apache/activemq/artemis/tests/util/Wait.java @@ -40,10 +40,10 @@ public class Wait { } public static boolean waitFor(final Condition condition, - final long duration, + final long durationMillis, final long sleepMillis) throws Exception { - final long expiry = System.currentTimeMillis() + duration; + final long expiry = System.currentTimeMillis() + durationMillis; boolean conditionSatisified = condition.isSatisfied(); while (!conditionSatisified && System.currentTimeMillis() < expiry) { TimeUnit.MILLISECONDS.sleep(sleepMillis);