NIFI-4475 Changing the get(batchSize) method in StandardProcessSession so that it checks all connections before returning nothing. This closes #2337.

This commit is contained in:
Joe Percivall 2017-12-13 17:17:05 -05:00 committed by Mark Payne
parent 57947d64cd
commit 48ae4be015
2 changed files with 140 additions and 90 deletions

View File

@ -16,33 +16,6 @@
*/ */
package org.apache.nifi.controller.repository; package org.apache.nifi.controller.repository;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.nifi.connectable.Connectable; import org.apache.nifi.connectable.Connectable;
import org.apache.nifi.connectable.Connection; import org.apache.nifi.connectable.Connection;
@ -82,6 +55,33 @@ import org.apache.nifi.stream.io.StreamUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/** /**
* <p> * <p>
* Provides a ProcessSession that ensures all accesses, changes and transfers * Provides a ProcessSession that ensures all accesses, changes and transfers
@ -1533,9 +1533,7 @@ public final class StandardProcessSession implements ProcessSession, ProvenanceE
return Collections.emptyList(); return Collections.emptyList();
} }
final Connection connection = connections.get(context.getNextIncomingConnectionIndex() % connections.size()); return get(new ConnectionPoller() {
return get(connection, new ConnectionPoller() {
@Override @Override
public List<FlowFileRecord> poll(final Connection connection, final Set<FlowFileRecord> expiredRecords) { public List<FlowFileRecord> poll(final Connection connection, final Set<FlowFileRecord> expiredRecords) {
return connection.poll(new FlowFileFilter() { return connection.poll(new FlowFileFilter() {

View File

@ -16,47 +16,6 @@
*/ */
package org.apache.nifi.controller.repository; package org.apache.nifi.controller.repository;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.notNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
import org.apache.nifi.connectable.Connectable; import org.apache.nifi.connectable.Connectable;
import org.apache.nifi.connectable.ConnectableType; import org.apache.nifi.connectable.ConnectableType;
import org.apache.nifi.connectable.Connection; import org.apache.nifi.connectable.Connection;
@ -96,6 +55,47 @@ import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.notNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class TestStandardProcessSession { public class TestStandardProcessSession {
private StandardProcessSession session; private StandardProcessSession session;
@ -194,23 +194,43 @@ public class TestStandardProcessSession {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private Connection createConnection() { private Connection createConnection() {
final Connection connection = Mockito.mock(Connection.class); AtomicReference<FlowFileQueue> queueReference = new AtomicReference<>(flowFileQueue);
Connection connection = createConnection(queueReference);
flowFileQueue = queueReference.get();
if (flowFileQueue == null) { return connection;
}
private FlowFileQueue createFlowFileQueueSpy(Connection connection) {
final FlowFileSwapManager swapManager = Mockito.mock(FlowFileSwapManager.class); final FlowFileSwapManager swapManager = Mockito.mock(FlowFileSwapManager.class);
final ProcessScheduler processScheduler = Mockito.mock(ProcessScheduler.class); final ProcessScheduler processScheduler = Mockito.mock(ProcessScheduler.class);
final StandardFlowFileQueue actualQueue = new StandardFlowFileQueue("1", connection, flowFileRepo, provenanceRepo, null, final StandardFlowFileQueue actualQueue = new StandardFlowFileQueue("1", connection, flowFileRepo, provenanceRepo, null,
processScheduler, swapManager, null, 10000); processScheduler, swapManager, null, 10000);
flowFileQueue = Mockito.spy(actualQueue); return Mockito.spy(actualQueue);
} }
when(connection.getFlowFileQueue()).thenReturn(flowFileQueue); @SuppressWarnings("unchecked")
private Connection createConnection(AtomicReference<FlowFileQueue> flowFileQueueReference) {
final Connection connection = Mockito.mock(Connection.class);
FlowFileQueue flowFileQueueFromReference = flowFileQueueReference.get();
final FlowFileQueue localFlowFileQueue;
if (flowFileQueueFromReference == null) {
localFlowFileQueue = createFlowFileQueueSpy(connection);
flowFileQueueReference.set(localFlowFileQueue);
} else {
localFlowFileQueue = flowFileQueueFromReference;
}
when(connection.getFlowFileQueue()).thenReturn(localFlowFileQueue);
Mockito.doAnswer(new Answer<Object>() { Mockito.doAnswer(new Answer<Object>() {
@Override @Override
public Object answer(InvocationOnMock invocation) throws Throwable { public Object answer(InvocationOnMock invocation) throws Throwable {
flowFileQueue.put((FlowFileRecord) invocation.getArguments()[0]); localFlowFileQueue.put((FlowFileRecord) invocation.getArguments()[0]);
return null; return null;
} }
}).when(connection).enqueue(Mockito.any(FlowFileRecord.class)); }).when(connection).enqueue(Mockito.any(FlowFileRecord.class));
@ -218,7 +238,7 @@ public class TestStandardProcessSession {
Mockito.doAnswer(new Answer<Object>() { Mockito.doAnswer(new Answer<Object>() {
@Override @Override
public Object answer(InvocationOnMock invocation) throws Throwable { public Object answer(InvocationOnMock invocation) throws Throwable {
flowFileQueue.putAll((Collection<FlowFileRecord>) invocation.getArguments()[0]); localFlowFileQueue.putAll((Collection<FlowFileRecord>) invocation.getArguments()[0]);
return null; return null;
} }
}).when(connection).enqueue(Mockito.any(Collection.class)); }).when(connection).enqueue(Mockito.any(Collection.class));
@ -230,14 +250,14 @@ public class TestStandardProcessSession {
Mockito.doAnswer(new Answer<FlowFile>() { Mockito.doAnswer(new Answer<FlowFile>() {
@Override @Override
public FlowFile answer(InvocationOnMock invocation) throws Throwable { public FlowFile answer(InvocationOnMock invocation) throws Throwable {
return flowFileQueue.poll(invocation.getArgumentAt(0, Set.class)); return localFlowFileQueue.poll(invocation.getArgumentAt(0, Set.class));
} }
}).when(connection).poll(any(Set.class)); }).when(connection).poll(any(Set.class));
Mockito.doAnswer(new Answer<List<FlowFileRecord>>() { Mockito.doAnswer(new Answer<List<FlowFileRecord>>() {
@Override @Override
public List<FlowFileRecord> answer(InvocationOnMock invocation) throws Throwable { public List<FlowFileRecord> answer(InvocationOnMock invocation) throws Throwable {
return flowFileQueue.poll(invocation.getArgumentAt(0, FlowFileFilter.class), invocation.getArgumentAt(1, Set.class)); return localFlowFileQueue.poll(invocation.getArgumentAt(0, FlowFileFilter.class), invocation.getArgumentAt(1, Set.class));
} }
}).when(connection).poll(any(FlowFileFilter.class), any(Set.class)); }).when(connection).poll(any(FlowFileFilter.class), any(Set.class));
@ -372,6 +392,38 @@ public class TestStandardProcessSession {
verify(conn2, times(1)).poll(any(FlowFileFilter.class), any(Set.class)); verify(conn2, times(1)).poll(any(FlowFileFilter.class), any(Set.class));
} }
@Test
@SuppressWarnings("unchecked")
public void testRoundRobinAcrossConnectionsOnSessionGetWithCount() {
final AtomicReference<FlowFileQueue> queue1Reference = new AtomicReference<>();
final AtomicReference<FlowFileQueue> queue2Reference = new AtomicReference<>();
final List<Connection> connList = new ArrayList<>();
final Connection conn1 = createConnection(queue1Reference);
final Connection conn2 = createConnection(queue2Reference);
connList.add(conn1);
connList.add(conn2);
final FlowFileQueue queue2 = queue2Reference.get();
final FlowFileRecord flowFileRecord = new StandardFlowFileRecord.Builder()
.id(1000L)
.addAttribute("uuid", "12345678-1234-1234-1234-123456789012")
.entryDate(System.currentTimeMillis())
.build();
queue2.put(flowFileRecord);
when(connectable.getIncomingConnections()).thenReturn(connList);
List<FlowFile> result = session.get(2);
assertEquals(1, result.size());
verify(conn1, times(1)).poll(any(FlowFileFilter.class), any(Set.class));
verify(conn2, times(1)).poll(any(FlowFileFilter.class), any(Set.class));
}
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testRoundRobinOnSessionGetWithFilter() { public void testRoundRobinOnSessionGetWithFilter() {
@ -580,19 +632,19 @@ public class TestStandardProcessSession {
outputStream.write(new byte[0]); outputStream.write(new byte[0]);
Assert.fail("Expected OutputStream to be disabled; was able to call write(byte[])"); Assert.fail("Expected OutputStream to be disabled; was able to call write(byte[])");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
try { try {
outputStream.write(0); outputStream.write(0);
Assert.fail("Expected OutputStream to be disabled; was able to call write(int)"); Assert.fail("Expected OutputStream to be disabled; was able to call write(int)");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
try { try {
outputStream.write(new byte[0], 0, 0); outputStream.write(new byte[0], 0, 0);
Assert.fail("Expected OutputStream to be disabled; was able to call write(byte[], int, int)"); Assert.fail("Expected OutputStream to be disabled; was able to call write(byte[], int, int)");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
} }
@ -601,31 +653,31 @@ public class TestStandardProcessSession {
inputStream.read(); inputStream.read();
Assert.fail("Expected InputStream to be disabled; was able to call read()"); Assert.fail("Expected InputStream to be disabled; was able to call read()");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
try { try {
inputStream.read(new byte[0]); inputStream.read(new byte[0]);
Assert.fail("Expected InputStream to be disabled; was able to call read(byte[])"); Assert.fail("Expected InputStream to be disabled; was able to call read(byte[])");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
try { try {
inputStream.read(new byte[0], 0, 0); inputStream.read(new byte[0], 0, 0);
Assert.fail("Expected InputStream to be disabled; was able to call read(byte[], int, int)"); Assert.fail("Expected InputStream to be disabled; was able to call read(byte[], int, int)");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
try { try {
inputStream.reset(); inputStream.reset();
Assert.fail("Expected InputStream to be disabled; was able to call reset()"); Assert.fail("Expected InputStream to be disabled; was able to call reset()");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
try { try {
inputStream.skip(1L); inputStream.skip(1L);
Assert.fail("Expected InputStream to be disabled; was able to call skip(long)"); Assert.fail("Expected InputStream to be disabled; was able to call skip(long)");
} catch (final Exception ex) { } catch (final Exception ex) {
Assert.assertEquals(FlowFileAccessException.class, ex.getClass()); assertEquals(FlowFileAccessException.class, ex.getClass());
} }
} }