diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java index ff6ee4335a..51475f9bfe 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java @@ -654,30 +654,31 @@ public class MQTTProtocolConverter { return mqttTransport; } - boolean willSent = false; + AtomicBoolean transportErrorHandled = new AtomicBoolean(false); public void onTransportError() { - if (connect != null) { - if (connected.get()) { - if (connect.willTopic() != null && connect.willMessage() != null && !willSent) { - willSent = true; - try { - PUBLISH publish = new PUBLISH(); - publish.topicName(connect.willTopic()); - publish.qos(connect.willQos()); - publish.messageId(packetIdGenerator.getNextSequenceId(getClientId())); - publish.payload(connect.willMessage()); - publish.retain(connect.willRetain()); - ActiveMQMessage message = convertMessage(publish); - message.setProducerId(producerId); - message.onSend(); + if (transportErrorHandled.compareAndSet(false, true)) { + if (connect != null) { + if (connected.get()) { + if (connect.willTopic() != null && connect.willMessage() != null) { + try { + PUBLISH publish = new PUBLISH(); + publish.topicName(connect.willTopic()); + publish.qos(connect.willQos()); + publish.messageId(packetIdGenerator.getNextSequenceId(getClientId())); + publish.payload(connect.willMessage()); + publish.retain(connect.willRetain()); + ActiveMQMessage message = convertMessage(publish); + message.setProducerId(producerId); + message.onSend(); - sendToActiveMQ(message, null); - } catch (Exception e) { - LOG.warn("Failed to publish Will Message " + connect.willMessage()); + sendToActiveMQ(message, null); + } catch (Exception e) { + LOG.warn("Failed to publish Will Message " + connect.willMessage()); + } } + // remove connection info + sendToActiveMQ(connectionInfo.createRemoveCommand(), null); } - // remove connection info - sendToActiveMQ(connectionInfo.createRemoveCommand(), null); } } } @@ -887,4 +888,9 @@ public class MQTTProtocolConverter { } return subsciptionStrategy; } + + // for testing + public void setSubsciptionStrategy(MQTTSubscriptionStrategy subsciptionStrategy) { + this.subsciptionStrategy = subsciptionStrategy; + } } diff --git a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverterTest.java b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverterTest.java index bfe3149d5c..c445f924ac 100644 --- a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverterTest.java +++ b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverterTest.java @@ -18,10 +18,20 @@ package org.apache.activemq.transport.mqtt; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.times; import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.command.Command; +import org.apache.activemq.command.ConnectionInfo; +import org.apache.activemq.command.ProducerInfo; +import org.apache.activemq.command.RemoveInfo; +import org.apache.activemq.command.Response; +import org.apache.activemq.transport.mqtt.strategy.MQTTSubscriptionStrategy; import org.fusesource.mqtt.codec.CONNACK; import org.fusesource.mqtt.codec.CONNECT; import org.fusesource.mqtt.codec.MQTTFrame; @@ -76,4 +86,50 @@ public class MQTTProtocolConverterTest { CONNACK connAck = new CONNACK().decode(response); assertEquals(CONNACK.Code.CONNECTION_REFUSED_UNACCEPTED_PROTOCOL_VERSION, connAck.code()); } + + @Test + public void testConcurrentOnTransportError() throws Exception { + MQTTProtocolConverter converter = new MQTTProtocolConverter(transport, broker); + converter.setSubsciptionStrategy(Mockito.mock(MQTTSubscriptionStrategy.class)); + + CONNECT connect = Mockito.mock(CONNECT.class); + + Mockito.when(connect.version()).thenReturn(3); + Mockito.when(connect.cleanSession()).thenReturn(true); + + converter.onMQTTConnect(connect); + + ArgumentCaptor connectionInfoArgumentCaptor = ArgumentCaptor.forClass(ConnectionInfo.class); + Mockito.verify(transport).sendToActiveMQ(connectionInfoArgumentCaptor.capture()); + + ConnectionInfo connectInfo = connectionInfoArgumentCaptor.getValue(); + Response ok = new Response(); + ok.setCorrelationId(connectInfo.getCommandId()); + converter.onActiveMQCommand(ok); + + ArgumentCaptor producerInfoArgumentCaptor = ArgumentCaptor.forClass(Command.class); + Mockito.verify(transport, times(3)).sendToActiveMQ(producerInfoArgumentCaptor.capture()); + + ProducerInfo producerInfo = (ProducerInfo) producerInfoArgumentCaptor.getValue(); + ok = new Response(); + ok.setCorrelationId(producerInfo.getCommandId()); + converter.onActiveMQCommand(ok); + + ExecutorService executorService = Executors.newCachedThreadPool(); + for (int i=0; i<10; i++) { + executorService.submit(new Runnable() { + @Override + public void run() { + converter.onTransportError(); + } + }); + } + + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + + ArgumentCaptor removeInfo = ArgumentCaptor.forClass(RemoveInfo.class); + Mockito.verify(transport, times(4)).sendToActiveMQ(removeInfo.capture()); + + } }