From 721628eb955f0f52d8ece0db2afe784c0a144008 Mon Sep 17 00:00:00 2001 From: Eric Secules Date: Tue, 3 Oct 2023 11:40:31 -0700 Subject: [PATCH] NIFI-12158 MockProcessSession write methods preserves attributes (#7828) Co-authored-by: Eric Secules --- .../org/apache/nifi/util/MockProcessSession.java | 13 ++++++------- .../apache/nifi/util/TestMockProcessSession.java | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/nifi-mock/src/main/java/org/apache/nifi/util/MockProcessSession.java b/nifi-mock/src/main/java/org/apache/nifi/util/MockProcessSession.java index 4ab99220a2..09cc1d3a5f 100644 --- a/nifi-mock/src/main/java/org/apache/nifi/util/MockProcessSession.java +++ b/nifi-mock/src/main/java/org/apache/nifi/util/MockProcessSession.java @@ -919,16 +919,15 @@ public class MockProcessSession implements ProcessSession { if (!(flowFile instanceof MockFlowFile)) { throw new IllegalArgumentException("Cannot export a flow file that I did not create"); } - final MockFlowFile mockFlowFile = validateState(flowFile); - writeRecursionSet.add(flowFile); + writeRecursionSet.add(mockFlowFile); final ByteArrayOutputStream baos = new ByteArrayOutputStream() { @Override public void close() throws IOException { super.close(); writeRecursionSet.remove(mockFlowFile); - final MockFlowFile newFlowFile = new MockFlowFile(mockFlowFile.getId(), flowFile); + final MockFlowFile newFlowFile = new MockFlowFile(mockFlowFile.getId(), mockFlowFile); currentVersions.put(newFlowFile.getId(), newFlowFile); newFlowFile.setData(toByteArray()); @@ -961,12 +960,12 @@ public class MockProcessSession implements ProcessSession { } @Override - public MockFlowFile write(final FlowFile flowFile, final StreamCallback callback) { + public MockFlowFile write(FlowFile flowFile, final StreamCallback callback) { + flowFile = validateState(flowFile); if (callback == null || flowFile == null) { throw new IllegalArgumentException("argument cannot be null"); } - final MockFlowFile mock = validateState(flowFile); - + final MockFlowFile mock = (MockFlowFile) flowFile; final ByteArrayInputStream in = new ByteArrayInputStream(mock.getData()); final ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -979,7 +978,7 @@ public class MockProcessSession implements ProcessSession { writeRecursionSet.remove(flowFile); } - final MockFlowFile newFlowFile = new MockFlowFile(mock.getId(), flowFile); + final MockFlowFile newFlowFile = new MockFlowFile(flowFile.getId(), flowFile); currentVersions.put(newFlowFile.getId(), newFlowFile); newFlowFile.setData(out.toByteArray()); diff --git a/nifi-mock/src/test/java/org/apache/nifi/util/TestMockProcessSession.java b/nifi-mock/src/test/java/org/apache/nifi/util/TestMockProcessSession.java index eefd4a39dd..775bc2f5ed 100644 --- a/nifi-mock/src/test/java/org/apache/nifi/util/TestMockProcessSession.java +++ b/nifi-mock/src/test/java/org/apache/nifi/util/TestMockProcessSession.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; import java.io.InputStream; import java.util.Collections; +import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicLong; @@ -133,6 +134,20 @@ public class TestMockProcessSession { assertFalse(ff1.isPenalized()); } + @Test + public void testAttributePreservedAfterWrite() throws IOException { + final Processor processor = new PoorlyBehavedProcessor(); + final MockProcessSession session = new MockProcessSession(new SharedSessionState(processor, new AtomicLong(0L)), processor, new MockStateManager(processor)); + FlowFile ff1 = session.createFlowFile("hello, world".getBytes()); + session.putAttribute(ff1, "key1", "val1"); + session.write(ff1).close(); + session.transfer(ff1, PoorlyBehavedProcessor.REL_FAILURE); + session.commitAsync(); + List output = session.getFlowFilesForRelationship(PoorlyBehavedProcessor.REL_FAILURE); + assertEquals(1, output.size()); + output.get(0).assertAttributeEquals("key1", "val1"); + } + protected static class PoorlyBehavedProcessor extends AbstractProcessor { private static final Relationship REL_FAILURE = new Relationship.Builder()