From 67876f707b7a0547e2be9bb270eb84c14ac3c718 Mon Sep 17 00:00:00 2001 From: Bryan Bende Date: Mon, 25 Jul 2016 21:19:40 -0400 Subject: [PATCH] NIFI-2399 Correcting comparison of maxEventId against lastEventId in SiteToSiteProvenanceReportingTask --- .../SiteToSiteProvenanceReportingTask.java | 6 +- ...TestSiteToSiteProvenanceReportingTask.java | 131 ++++++++++++------ 2 files changed, 95 insertions(+), 42 deletions(-) diff --git a/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/main/java/org/apache/nifi/reporting/SiteToSiteProvenanceReportingTask.java b/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/main/java/org/apache/nifi/reporting/SiteToSiteProvenanceReportingTask.java index a6eb66295b..8c2bd337ac 100644 --- a/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/main/java/org/apache/nifi/reporting/SiteToSiteProvenanceReportingTask.java +++ b/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/main/java/org/apache/nifi/reporting/SiteToSiteProvenanceReportingTask.java @@ -60,8 +60,8 @@ import java.util.concurrent.TimeUnit; @Stateful(scopes = Scope.LOCAL, description = "Stores the Reporting Task's last event Id so that on restart the task knows where it left off.") public class SiteToSiteProvenanceReportingTask extends AbstractSiteToSiteReportingTask { - private static final String TIMESTAMP_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; - private static final String LAST_EVENT_ID_KEY = "last_event_id"; + static final String TIMESTAMP_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; + static final String LAST_EVENT_ID_KEY = "last_event_id"; static final PropertyDescriptor PLATFORM = new PropertyDescriptor.Builder() .name("Platform") @@ -136,7 +136,7 @@ public class SiteToSiteProvenanceReportingTask extends AbstractSiteToSiteReporti firstEventId = Long.parseLong(state.get(LAST_EVENT_ID_KEY)) + 1; } - if(currMaxId < firstEventId){ + if(currMaxId < (firstEventId - 1)){ getLogger().warn("Current provenance max id is {} which is less than what was stored in state as the last queried event, which was {}. This means the provenance restarted its " + "ids. Restarting querying from the beginning.", new Object[]{currMaxId, firstEventId}); firstEventId = -1; diff --git a/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/test/java/org/apache/nifi/reporting/TestSiteToSiteProvenanceReportingTask.java b/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/test/java/org/apache/nifi/reporting/TestSiteToSiteProvenanceReportingTask.java index a048f5b211..493009416f 100644 --- a/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/test/java/org/apache/nifi/reporting/TestSiteToSiteProvenanceReportingTask.java +++ b/nifi-nar-bundles/nifi-site-to-site-reporting-bundle/nifi-site-to-site-reporting-task/src/test/java/org/apache/nifi/reporting/TestSiteToSiteProvenanceReportingTask.java @@ -17,19 +17,10 @@ package org.apache.nifi.reporting; -import static org.junit.Assert.assertEquals; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.atomic.AtomicInteger; import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.PropertyValue; +import org.apache.nifi.components.state.Scope; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.provenance.ProvenanceEventBuilder; @@ -53,6 +44,16 @@ import org.mockito.stubbing.Answer; import javax.json.Json; import javax.json.JsonObject; import javax.json.JsonReader; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; public class TestSiteToSiteProvenanceReportingTask { @@ -78,33 +79,7 @@ public class TestSiteToSiteProvenanceReportingTask { builder.setComponentType("dummy processor"); final ProvenanceEventRecord event = builder.build(); - final List dataSent = new ArrayList<>(); - final SiteToSiteProvenanceReportingTask task = new SiteToSiteProvenanceReportingTask() { - @SuppressWarnings("unchecked") - @Override - protected SiteToSiteClient getClient() { - final SiteToSiteClient client = Mockito.mock(SiteToSiteClient.class); - final Transaction transaction = Mockito.mock(Transaction.class); - - try { - Mockito.doAnswer(new Answer() { - @Override - public Object answer(final InvocationOnMock invocation) throws Throwable { - final byte[] data = invocation.getArgumentAt(0, byte[].class); - dataSent.add(data); - return null; - } - }).when(transaction).send(Mockito.any(byte[].class), Mockito.any(Map.class)); - - Mockito.when(client.createTransaction(Mockito.any(TransferDirection.class))).thenReturn(transaction); - } catch (final Exception e) { - e.printStackTrace(); - Assert.fail(e.toString()); - } - - return client; - } - }; + final MockSiteToSiteProvenanceReportingTask task = new MockSiteToSiteProvenanceReportingTask(); final Map properties = new HashMap<>(); for (final PropertyDescriptor descriptor : task.getSupportedPropertyDescriptors()) { @@ -162,16 +137,94 @@ public class TestSiteToSiteProvenanceReportingTask { task.initialize(initContext); task.onTrigger(context); - assertEquals(3, dataSent.size()); - final String msg = new String(dataSent.get(0), StandardCharsets.UTF_8); + assertEquals(3, task.dataSent.size()); + final String msg = new String(task.dataSent.get(0), StandardCharsets.UTF_8); JsonReader jsonReader = Json.createReader(new ByteArrayInputStream(msg.getBytes())); JsonObject msgArray = jsonReader.readArray().getJsonObject(0).getJsonObject("updatedAttributes"); assertEquals(msgArray.getString("abc"), event.getAttributes().get("abc")); } + @Test + public void testWhenProvenanceMaxIdEqualToLastEventIdInStateManager() throws IOException, InitializationException { + final long maxEventId = 2500; + + // create the mock reporting task and mock state manager + final MockSiteToSiteProvenanceReportingTask task = new MockSiteToSiteProvenanceReportingTask(); + final MockStateManager stateManager = new MockStateManager(task); + + // create the state map and set the last id to the same value as maxEventId + final Map state = new HashMap<>(); + state.put(SiteToSiteProvenanceReportingTask.LAST_EVENT_ID_KEY, String.valueOf(maxEventId)); + stateManager.setState(state, Scope.LOCAL); + + // setup the mock reporting context to return the mock state manager + final ReportingContext context = Mockito.mock(ReportingContext.class); + Mockito.when(context.getStateManager()).thenReturn(stateManager); + + // setup the mock provenance repository to return maxEventId + final ProvenanceEventRepository provenanceRepository = Mockito.mock(ProvenanceEventRepository.class); + Mockito.doAnswer(new Answer() { + @Override + public Long answer(final InvocationOnMock invocation) throws Throwable { + return maxEventId; + } + }).when(provenanceRepository).getMaxEventId(); + + // setup the mock EventAccess to return the mock provenance repository + final EventAccess eventAccess = Mockito.mock(EventAccess.class); + Mockito.when(context.getEventAccess()).thenReturn(eventAccess); + Mockito.when(eventAccess.getProvenanceRepository()).thenReturn(provenanceRepository); + + // setup the mock initialization context + final ComponentLog logger = Mockito.mock(ComponentLog.class); + final ReportingInitializationContext initContext = Mockito.mock(ReportingInitializationContext.class); + Mockito.when(initContext.getIdentifier()).thenReturn(UUID.randomUUID().toString()); + Mockito.when(initContext.getLogger()).thenReturn(logger); + + task.initialize(initContext); + + // execute the reporting task and should not produce any data b/c max id same as previous id + task.onTrigger(context); + assertEquals(0, task.dataSent.size()); + } + public static FlowFile createFlowFile(final long id, final Map attributes) { MockFlowFile mockFlowFile = new MockFlowFile(id); mockFlowFile.putAttributes(attributes); return mockFlowFile; } + + private static final class MockSiteToSiteProvenanceReportingTask extends SiteToSiteProvenanceReportingTask { + + final List dataSent = new ArrayList<>(); + + @Override + protected SiteToSiteClient getClient() { + final SiteToSiteClient client = Mockito.mock(SiteToSiteClient.class); + final Transaction transaction = Mockito.mock(Transaction.class); + + try { + Mockito.doAnswer(new Answer() { + @Override + public Object answer(final InvocationOnMock invocation) throws Throwable { + final byte[] data = invocation.getArgumentAt(0, byte[].class); + dataSent.add(data); + return null; + } + }).when(transaction).send(Mockito.any(byte[].class), Mockito.any(Map.class)); + + Mockito.when(client.createTransaction(Mockito.any(TransferDirection.class))).thenReturn(transaction); + } catch (final Exception e) { + e.printStackTrace(); + Assert.fail(e.toString()); + } + + return client; + } + + public List getDataSent() { + return dataSent; + } + } + }