diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSslTransportFactory.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSslTransportFactory.java index a0d32b11a2..9935387919 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSslTransportFactory.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTSslTransportFactory.java @@ -16,15 +16,20 @@ */ package org.apache.activemq.transport.mqtt; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.HashMap; import java.util.Map; -import org.apache.activemq.broker.BrokerContext; +import javax.net.ssl.SSLServerSocketFactory; + import org.apache.activemq.broker.BrokerService; import org.apache.activemq.broker.BrokerServiceAware; import org.apache.activemq.transport.MutexTransport; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.tcp.SslTransportFactory; +import org.apache.activemq.transport.tcp.SslTransportServer; import org.apache.activemq.util.IntrospectionSupport; import org.apache.activemq.wireformat.WireFormat; @@ -47,6 +52,13 @@ public class MQTTSslTransportFactory extends SslTransportFactory implements Brok return super.compositeConfigure(transport, format, options); } + @Override + protected SslTransportServer createSslTransportServer(URI location, SSLServerSocketFactory serverSocketFactory) throws IOException, URISyntaxException { + final SslTransportServer server = super.createSslTransportServer(location, serverSocketFactory); + server.setAllowLinkStealing(true); + return server; + } + @SuppressWarnings("rawtypes") @Override public Transport serverConfigure(Transport transport, WireFormat format, HashMap options) throws Exception { @@ -56,7 +68,6 @@ public class MQTTSslTransportFactory extends SslTransportFactory implements Brok if (mutex != null) { mutex.setSyncOnCommand(true); } - return transport; } diff --git a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java index 8734d65d0e..8d66ae37e2 100644 --- a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java +++ b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java @@ -18,6 +18,7 @@ package org.apache.activemq.transport.mqtt; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -1083,6 +1084,56 @@ public class MQTTTest extends MQTTTestSupport { connection.disconnect(); } + @Test(timeout = 60 * 1000) + public void testDuplicateClientId() throws Exception { + // test link stealing enabled by default + stopBroker(); + startBroker(); + + final String clientId = "duplicateClient"; + MQTT mqtt = createMQTTConnection(clientId, false); + mqtt.setKeepAlive((short) 2); + BlockingConnection connection = mqtt.blockingConnection(); + connection.connect(); + final String TOPICA = "TopicA"; + connection.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + + MQTT mqtt1 = createMQTTConnection(clientId, false); + mqtt1.setKeepAlive((short) 2); + BlockingConnection connection1 = mqtt1.blockingConnection(); + connection1.connect(); + + assertTrue("Duplicate client disconnected", connection1.isConnected()); + assertFalse("Old client still connected", connection.isConnected()); + connection1.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + connection1.disconnect(); + + // disable link stealing + stopBroker(); + protocolConfig = "allowLinkStealing=false"; + startBroker(); + + mqtt = createMQTTConnection(clientId, false); + mqtt.setKeepAlive((short) 2); + connection = mqtt.blockingConnection(); + connection.connect(); + connection.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + + mqtt1 = createMQTTConnection(clientId, false); + mqtt1.setKeepAlive((short) 2); + connection1 = mqtt1.blockingConnection(); + try { + connection1.connect(); + fail("Duplicate client connected"); + } catch (Exception e) { + // ignore + } + + assertTrue("Old client disconnected", connection.isConnected()); + connection.publish(TOPICA, TOPICA.getBytes(), QoS.EXACTLY_ONCE, true); + connection.disconnect(); + } + @Test(timeout = 30 * 10000) public void testJmsMapping() throws Exception { // start up jms consumer