diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/repository/StandardProcessSession.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/repository/StandardProcessSession.java index c4ea132b50..24fbf72312 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/repository/StandardProcessSession.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/repository/StandardProcessSession.java @@ -16,33 +16,6 @@ */ package org.apache.nifi.controller.repository; -import java.io.BufferedOutputStream; -import java.io.ByteArrayInputStream; -import java.io.Closeable; -import java.io.EOFException; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Objects; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - import org.apache.commons.io.IOUtils; import org.apache.nifi.connectable.Connectable; import org.apache.nifi.connectable.Connection; @@ -82,6 +55,33 @@ import org.apache.nifi.stream.io.StreamUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.BufferedOutputStream; +import java.io.ByteArrayInputStream; +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + /** *

* Provides a ProcessSession that ensures all accesses, changes and transfers @@ -1533,9 +1533,7 @@ public final class StandardProcessSession implements ProcessSession, ProvenanceE return Collections.emptyList(); } - final Connection connection = connections.get(context.getNextIncomingConnectionIndex() % connections.size()); - - return get(connection, new ConnectionPoller() { + return get(new ConnectionPoller() { @Override public List poll(final Connection connection, final Set expiredRecords) { return connection.poll(new FlowFileFilter() { diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/repository/TestStandardProcessSession.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/repository/TestStandardProcessSession.java index 68d13e322e..cf91ae2a04 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/repository/TestStandardProcessSession.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/repository/TestStandardProcessSession.java @@ -16,47 +16,6 @@ */ package org.apache.nifi.controller.repository; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.notNull; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.FileOutputStream; -import java.io.FilterOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.regex.Pattern; - import org.apache.nifi.connectable.Connectable; import org.apache.nifi.connectable.ConnectableType; import org.apache.nifi.connectable.Connection; @@ -96,6 +55,47 @@ import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.notNull; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + public class TestStandardProcessSession { private StandardProcessSession session; @@ -194,23 +194,43 @@ public class TestStandardProcessSession { @SuppressWarnings("unchecked") private Connection createConnection() { + AtomicReference queueReference = new AtomicReference<>(flowFileQueue); + Connection connection = createConnection(queueReference); + flowFileQueue = queueReference.get(); + + return connection; + } + + private FlowFileQueue createFlowFileQueueSpy(Connection connection) { + final FlowFileSwapManager swapManager = Mockito.mock(FlowFileSwapManager.class); + final ProcessScheduler processScheduler = Mockito.mock(ProcessScheduler.class); + + final StandardFlowFileQueue actualQueue = new StandardFlowFileQueue("1", connection, flowFileRepo, provenanceRepo, null, + processScheduler, swapManager, null, 10000); + return Mockito.spy(actualQueue); + } + + @SuppressWarnings("unchecked") + private Connection createConnection(AtomicReference flowFileQueueReference) { final Connection connection = Mockito.mock(Connection.class); - if (flowFileQueue == null) { - final FlowFileSwapManager swapManager = Mockito.mock(FlowFileSwapManager.class); - final ProcessScheduler processScheduler = Mockito.mock(ProcessScheduler.class); + FlowFileQueue flowFileQueueFromReference = flowFileQueueReference.get(); - final StandardFlowFileQueue actualQueue = new StandardFlowFileQueue("1", connection, flowFileRepo, provenanceRepo, null, - processScheduler, swapManager, null, 10000); - flowFileQueue = Mockito.spy(actualQueue); + final FlowFileQueue localFlowFileQueue; + + if (flowFileQueueFromReference == null) { + localFlowFileQueue = createFlowFileQueueSpy(connection); + flowFileQueueReference.set(localFlowFileQueue); + } else { + localFlowFileQueue = flowFileQueueFromReference; } - when(connection.getFlowFileQueue()).thenReturn(flowFileQueue); + when(connection.getFlowFileQueue()).thenReturn(localFlowFileQueue); Mockito.doAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { - flowFileQueue.put((FlowFileRecord) invocation.getArguments()[0]); + localFlowFileQueue.put((FlowFileRecord) invocation.getArguments()[0]); return null; } }).when(connection).enqueue(Mockito.any(FlowFileRecord.class)); @@ -218,7 +238,7 @@ public class TestStandardProcessSession { Mockito.doAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocation) throws Throwable { - flowFileQueue.putAll((Collection) invocation.getArguments()[0]); + localFlowFileQueue.putAll((Collection) invocation.getArguments()[0]); return null; } }).when(connection).enqueue(Mockito.any(Collection.class)); @@ -230,14 +250,14 @@ public class TestStandardProcessSession { Mockito.doAnswer(new Answer() { @Override public FlowFile answer(InvocationOnMock invocation) throws Throwable { - return flowFileQueue.poll(invocation.getArgumentAt(0, Set.class)); + return localFlowFileQueue.poll(invocation.getArgumentAt(0, Set.class)); } }).when(connection).poll(any(Set.class)); Mockito.doAnswer(new Answer>() { @Override public List answer(InvocationOnMock invocation) throws Throwable { - return flowFileQueue.poll(invocation.getArgumentAt(0, FlowFileFilter.class), invocation.getArgumentAt(1, Set.class)); + return localFlowFileQueue.poll(invocation.getArgumentAt(0, FlowFileFilter.class), invocation.getArgumentAt(1, Set.class)); } }).when(connection).poll(any(FlowFileFilter.class), any(Set.class)); @@ -372,6 +392,38 @@ public class TestStandardProcessSession { verify(conn2, times(1)).poll(any(FlowFileFilter.class), any(Set.class)); } + @Test + @SuppressWarnings("unchecked") + public void testRoundRobinAcrossConnectionsOnSessionGetWithCount() { + final AtomicReference queue1Reference = new AtomicReference<>(); + final AtomicReference queue2Reference = new AtomicReference<>(); + + final List connList = new ArrayList<>(); + final Connection conn1 = createConnection(queue1Reference); + final Connection conn2 = createConnection(queue2Reference); + connList.add(conn1); + connList.add(conn2); + + final FlowFileQueue queue2 = queue2Reference.get(); + + final FlowFileRecord flowFileRecord = new StandardFlowFileRecord.Builder() + .id(1000L) + .addAttribute("uuid", "12345678-1234-1234-1234-123456789012") + .entryDate(System.currentTimeMillis()) + .build(); + + queue2.put(flowFileRecord); + + when(connectable.getIncomingConnections()).thenReturn(connList); + + List result = session.get(2); + + assertEquals(1, result.size()); + + verify(conn1, times(1)).poll(any(FlowFileFilter.class), any(Set.class)); + verify(conn2, times(1)).poll(any(FlowFileFilter.class), any(Set.class)); + } + @Test @SuppressWarnings("unchecked") public void testRoundRobinOnSessionGetWithFilter() { @@ -580,19 +632,19 @@ public class TestStandardProcessSession { outputStream.write(new byte[0]); Assert.fail("Expected OutputStream to be disabled; was able to call write(byte[])"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } try { outputStream.write(0); Assert.fail("Expected OutputStream to be disabled; was able to call write(int)"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } try { outputStream.write(new byte[0], 0, 0); Assert.fail("Expected OutputStream to be disabled; was able to call write(byte[], int, int)"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } } @@ -601,31 +653,31 @@ public class TestStandardProcessSession { inputStream.read(); Assert.fail("Expected InputStream to be disabled; was able to call read()"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } try { inputStream.read(new byte[0]); Assert.fail("Expected InputStream to be disabled; was able to call read(byte[])"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } try { inputStream.read(new byte[0], 0, 0); Assert.fail("Expected InputStream to be disabled; was able to call read(byte[], int, int)"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } try { inputStream.reset(); Assert.fail("Expected InputStream to be disabled; was able to call reset()"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } try { inputStream.skip(1L); Assert.fail("Expected InputStream to be disabled; was able to call skip(long)"); } catch (final Exception ex) { - Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); + assertEquals(FlowFileAccessException.class, ex.getClass()); } }