diff --git a/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/main/java/org/apache/nifi/processors/mqtt/ConsumeMQTT.java b/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/main/java/org/apache/nifi/processors/mqtt/ConsumeMQTT.java index ff70f7fc0b..659dd2f13d 100644 --- a/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/main/java/org/apache/nifi/processors/mqtt/ConsumeMQTT.java +++ b/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/main/java/org/apache/nifi/processors/mqtt/ConsumeMQTT.java @@ -32,6 +32,7 @@ import org.apache.nifi.annotation.lifecycle.OnUnscheduled; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.flowfile.attributes.CoreAttributes; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessorInitializationContext; import org.apache.nifi.processor.Relationship; @@ -289,8 +290,15 @@ public class ConsumeMQTT extends AbstractMQTTProcessor { String transitUri = new StringBuilder(broker).append(mqttMessage.getTopic()).toString(); session.getProvenanceReporter().receive(messageFlowfile, transitUri); session.transfer(messageFlowfile, REL_MESSAGE); - mqttQueue.remove(mqttMessage); session.commit(); + if (!mqttQueue.remove(mqttMessage) && logger.isWarnEnabled()) { + logger.warn(new StringBuilder("FlowFile ") + .append(messageFlowfile.getAttribute(CoreAttributes.UUID.key())) + .append(" for Mqtt message ") + .append(mqttMessage) + .append(" had already been removed from queue, possible duplication of flow files") + .toString()); + } } } diff --git a/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/TestConsumeMQTT.java b/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/TestConsumeMQTT.java index 58c37e505a..144cd63614 100644 --- a/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/TestConsumeMQTT.java +++ b/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/TestConsumeMQTT.java @@ -18,6 +18,8 @@ package org.apache.nifi.processors.mqtt; import io.moquette.proto.messages.PublishMessage; +import org.apache.nifi.processor.ProcessSession; +import org.apache.nifi.processors.mqtt.common.MQTTQueueMessage; import org.apache.nifi.processors.mqtt.common.MqttTestClient; import org.apache.nifi.processors.mqtt.common.TestConsumeMqttCommon; import org.apache.nifi.util.TestRunners; @@ -26,17 +28,23 @@ import org.eclipse.paho.client.mqttv3.MqttException; import org.eclipse.paho.client.mqttv3.MqttMessage; import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence; import org.junit.After; -import org.junit.Assert; import org.junit.Before; +import org.junit.Test; import java.io.File; import java.io.FilenameFilter; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Proxy; +import java.util.concurrent.BlockingQueue; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TestConsumeMQTT extends TestConsumeMqttCommon { - - public MqttTestClient mqttTestClient; public class UnitTestableConsumeMqtt extends ConsumeMQTT { @@ -65,6 +73,35 @@ public class TestConsumeMQTT extends TestConsumeMqttCommon { testRunner.setProperty(ConsumeMQTT.PROP_MAX_QUEUE_SIZE, "100"); } + /** + * If the session.commit() fails, we should not remove the unprocessed message + */ + @Test + public void testMessageNotConsumedOnCommitFail() throws NoSuchFieldException, IllegalAccessException, NoSuchMethodException, InvocationTargetException { + testRunner.run(1, false); + ConsumeMQTT processor = (ConsumeMQTT) testRunner.getProcessor(); + MQTTQueueMessage mock = mock(MQTTQueueMessage.class); + when(mock.getPayload()).thenReturn(new byte[0]); + when(mock.getTopic()).thenReturn("testTopic"); + BlockingQueue mqttQueue = getMqttQueue(processor); + mqttQueue.add(mock); + try { + ProcessSession session = testRunner.getProcessSessionFactory().createSession(); + transferQueue(processor, + (ProcessSession) Proxy.newProxyInstance(getClass().getClassLoader(), new Class[] { ProcessSession.class }, (proxy, method, args) -> { + if (method.getName().equals("commit")) { + throw new RuntimeException(); + } else { + return method.invoke(session, args); + } + })); + fail("Expected runtime exception"); + } catch (InvocationTargetException e) { + assertTrue("Expected generic runtime exception, not " + e, e.getCause() instanceof RuntimeException); + } + assertTrue("Expected mqttQueue to contain uncommitted message.", mqttQueue.contains(mock)); + } + @After public void tearDown() throws Exception { if (MQTT_server != null) { @@ -95,7 +132,7 @@ public class TestConsumeMQTT extends TestConsumeMqttCommon { try { mqttTestClient.publish(publishMessage.getTopicName(), mqttMessage); } catch (MqttException e) { - Assert.fail("Should never get an MqttException when publishing using test client"); + fail("Should never get an MqttException when publishing using test client"); } } } diff --git a/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/common/TestConsumeMqttCommon.java b/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/common/TestConsumeMqttCommon.java index d010d1d7c8..a9159ad874 100644 --- a/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/common/TestConsumeMqttCommon.java +++ b/nifi-nar-bundles/nifi-mqtt-bundle/nifi-mqtt-processors/src/test/java/org/apache/nifi/processors/mqtt/common/TestConsumeMqttCommon.java @@ -20,6 +20,7 @@ package org.apache.nifi.processors.mqtt.common; import io.moquette.proto.messages.AbstractMessage; import io.moquette.proto.messages.PublishMessage; import io.moquette.server.Server; +import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processors.mqtt.ConsumeMQTT; import org.apache.nifi.provenance.ProvenanceEventRecord; import org.apache.nifi.provenance.ProvenanceEventType; @@ -34,6 +35,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.util.List; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import static org.apache.nifi.processors.mqtt.ConsumeMQTT.BROKER_ATTRIBUTE_KEY; @@ -400,6 +402,18 @@ public abstract class TestConsumeMqttCommon { method.invoke(processor); } + public static BlockingQueue getMqttQueue(ConsumeMQTT consumeMQTT) throws IllegalAccessException, NoSuchFieldException { + Field mqttQueueField = ConsumeMQTT.class.getDeclaredField("mqttQueue"); + mqttQueueField.setAccessible(true); + return (BlockingQueue) mqttQueueField.get(consumeMQTT); + } + + public static void transferQueue(ConsumeMQTT consumeMQTT, ProcessSession session) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method transferQueue = ConsumeMQTT.class.getDeclaredMethod("transferQueue", ProcessSession.class); + transferQueue.setAccessible(true); + transferQueue.invoke(consumeMQTT, session); + } + private void assertProvenanceEvents(int count){ List provenanceEvents = testRunner.getProvenanceEvents(); assertNotNull(provenanceEvents);