NIFI-9207 - Added Max Read Size to Distributed Cache Servers

This closes #5420

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Paul Grey 2021-09-09 11:39:21 -04:00 committed by exceptionfactory
parent efc1cb012f
commit 46a5e3f096
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
10 changed files with 126 additions and 33 deletions

View File

@ -45,7 +45,7 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
private ChannelPromise channelPromise;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws IOException {
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
final ByteBuf byteBuf = (ByteBuf) msg;
try {
final byte[] bytes = new byte[byteBuf.readableBytes()];
@ -57,13 +57,25 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
public void channelReadComplete(final ChannelHandlerContext ctx) throws IOException {
inboundAdapter.dequeue();
if (inboundAdapter.isComplete() && !channelPromise.isSuccess()) {
channelPromise.setSuccess();
}
}
@Override
public void channelUnregistered(final ChannelHandlerContext ctx) {
if (!inboundAdapter.isComplete()) {
channelPromise.setFailure(new IOException("Channel unregistered before processing completed: " + ctx.channel().toString()));
}
}
@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
channelPromise.setFailure(cause);
}
/**
* Perform a synchronous method call to the server. The server is expected to write
* a byte stream response to the channel, which may be deserialized into a Java object
@ -86,5 +98,8 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
channel.writeAndFlush(Unpooled.wrappedBuffer(outboundAdapter.toBytes()));
channelPromise.awaitUninterruptibly();
this.inboundAdapter = new NullInboundAdapter();
if (channelPromise.cause() != null) {
throw new IOException("Request invocation failed", channelPromise.cause());
}
}
}

View File

@ -84,7 +84,7 @@ public class DistributedMapCacheClientService extends AbstractControllerService
.build();
/**
* The implementation of the business logic for {@link DistributedSetCacheClientService}.
* The implementation of the business logic for {@link DistributedMapCacheClientService}.
*/
private volatile NettyDistributedMapCacheClient cacheClient = null;

View File

@ -32,7 +32,7 @@ public class NullInboundAdapter implements InboundAdapter {
@Override
public boolean isComplete() {
return false;
return true;
}
@Override

View File

@ -18,6 +18,7 @@ package org.apache.nifi.distributed.cache.server;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@ -49,16 +50,18 @@ public abstract class AbstractCacheServer implements CacheServer {
private final String identifier;
private final int port;
private final int maxReadSize;
private final SSLContext sslContext;
protected volatile boolean stopped = false;
private final Set<Thread> processInputThreads = new CopyOnWriteArraySet<>();
private volatile ServerSocketChannel serverSocketChannel;
public AbstractCacheServer(final String identifier, final SSLContext sslContext, final int port) {
public AbstractCacheServer(final String identifier, final SSLContext sslContext, final int port, final int maxReadSize) {
this.identifier = identifier;
this.port = port;
this.sslContext = sslContext;
this.maxReadSize = maxReadSize;
}
@Override
@ -108,14 +111,14 @@ public abstract class AbstractCacheServer implements CacheServer {
rawInputStream = new SSLSocketChannelInputStream(sslSocketChannel);
rawOutputStream = new SSLSocketChannelOutputStream(sslSocketChannel);
}
} catch (IOException e) {
} catch (final IOException e) {
logger.error("Cannot create input and/or output streams for {}", new Object[]{identifier}, e);
if (logger.isDebugEnabled()) {
logger.error("", e);
}
try {
socketChannel.close();
} catch (IOException swallow) {
} catch (final IOException swallow) {
}
return;
@ -179,19 +182,19 @@ public abstract class AbstractCacheServer implements CacheServer {
if (serverSocketChannel != null && serverSocketChannel.isOpen()) {
try {
serverSocketChannel.close();
} catch (IOException e) {
} catch (final IOException e) {
logger.warn("Server Socket Close Failed", e);
}
}
// need to close out the created SocketChannels...this is done by interrupting
// the created threads that loop on listen().
for (Thread processInputThread : processInputThreads) {
for (final Thread processInputThread : processInputThreads) {
processInputThread.interrupt();
int i = 0;
while (!processInputThread.isInterrupted() && i++ < 5) {
try {
Thread.sleep(50); // allow thread to gracefully terminate
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
}
}
}
@ -213,4 +216,32 @@ public abstract class AbstractCacheServer implements CacheServer {
* @throws IOException ex
*/
protected abstract boolean listen(InputStream in, OutputStream out, int version) throws IOException;
/**
* Read a length-prefixed value from the {@link DataInputStream}.
*
* @param dis the {@link DataInputStream} from which to read the value
* @return the serialized representation of the value
* @throws IOException on failure to read from the input stream
*/
protected byte[] readValue(final DataInputStream dis) throws IOException {
final int numBytes = validateSize(dis.readInt());
final byte[] buffer = new byte[numBytes];
dis.readFully(buffer);
return buffer;
}
/**
* Validate a size value received from the {@link DataInputStream} against the configured maximum.
*
* @param size the size value received from the {@link DataInputStream}
* @return the size value, iff it passes validation; otherwise, an exception is thrown
*/
protected int validateSize(final int size) {
if (size <= maxReadSize) {
return size;
} else {
throw new IllegalStateException(String.format("Size [%d] exceeds maximum configured read [%d]", size, maxReadSize));
}
}
}

View File

@ -68,6 +68,14 @@ public abstract class DistributedCacheServer extends AbstractControllerService {
.required(false)
.addValidator(StandardValidators.createDirectoryExistsValidator(true, true))
.build();
public static final PropertyDescriptor MAX_READ_SIZE = new PropertyDescriptor.Builder()
.name("maximum-read-size")
.displayName("Maximum Read Size")
.description("The maximum number of network bytes to read for a single cache item")
.required(false)
.addValidator(StandardValidators.DATA_SIZE_VALIDATOR)
.defaultValue("1 MB")
.build();
private volatile CacheServer cacheServer;
@ -79,6 +87,7 @@ public abstract class DistributedCacheServer extends AbstractControllerService {
properties.add(EVICTION_POLICY);
properties.add(PERSISTENCE_PATH);
properties.add(SSL_CONTEXT_SERVICE);
properties.add(MAX_READ_SIZE);
return properties;
}

View File

@ -21,6 +21,7 @@ import javax.net.ssl.SSLContext;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.ssl.SSLContextService;
@Tags({"distributed", "set", "distinct", "cache", "server"})
@ -35,6 +36,7 @@ public class DistributedSetCacheServer extends DistributedCacheServer {
final SSLContextService sslContextService = context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class);
final int maxSize = context.getProperty(MAX_CACHE_ENTRIES).asInteger();
final String evictionPolicyName = context.getProperty(EVICTION_POLICY).getValue();
final int maxReadSize = context.getProperty(MAX_READ_SIZE).asDataSize(DataUnit.B).intValue();
final SSLContext sslContext;
if (sslContextService == null) {
@ -61,7 +63,7 @@ public class DistributedSetCacheServer extends DistributedCacheServer {
try {
final File persistenceDir = persistencePath == null ? null : new File(persistencePath);
return new SetCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir);
return new SetCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir, maxReadSize);
} catch (final Exception e) {
throw new RuntimeException(e);
}

View File

@ -36,8 +36,8 @@ public class SetCacheServer extends AbstractCacheServer {
private final SetCache cache;
public SetCacheServer(final String identifier, final SSLContext sslContext, final int port, final int maxSize,
final EvictionPolicy evictionPolicy, final File persistencePath) throws IOException {
super(identifier, sslContext, port);
final EvictionPolicy evictionPolicy, final File persistencePath, final int maxReadSize) throws IOException {
super(identifier, sslContext, port, maxReadSize);
final SetCache simpleCache = new SimpleSetCache(identifier, maxSize, evictionPolicy);
@ -60,9 +60,7 @@ public class SetCacheServer extends AbstractCacheServer {
return false;
}
final int valueLength = dis.readInt();
final byte[] value = new byte[valueLength];
dis.readFully(value);
final byte[] value = readValue(dis);
final ByteBuffer valueBuffer = ByteBuffer.wrap(value);
final SetCacheResult response;
@ -101,5 +99,4 @@ public class SetCacheServer extends AbstractCacheServer {
stop();
}
}
}

View File

@ -26,6 +26,7 @@ import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.distributed.cache.server.CacheServer;
import org.apache.nifi.distributed.cache.server.DistributedCacheServer;
import org.apache.nifi.distributed.cache.server.EvictionPolicy;
import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.ssl.SSLContextService;
@Tags({"distributed", "cluster", "map", "cache", "server", "key/value"})
@ -41,6 +42,7 @@ public class DistributedMapCacheServer extends DistributedCacheServer {
final SSLContextService sslContextService = context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class);
final int maxSize = context.getProperty(MAX_CACHE_ENTRIES).asInteger();
final String evictionPolicyName = context.getProperty(EVICTION_POLICY).getValue();
final int maxReadSize = context.getProperty(MAX_READ_SIZE).asDataSize(DataUnit.B).intValue();
final SSLContext sslContext;
if (sslContextService == null) {
@ -67,14 +69,16 @@ public class DistributedMapCacheServer extends DistributedCacheServer {
try {
final File persistenceDir = persistencePath == null ? null : new File(persistencePath);
return createMapCacheServer(port, maxSize, sslContext, evictionPolicy, persistenceDir);
return createMapCacheServer(port, maxSize, sslContext, evictionPolicy, persistenceDir, maxReadSize);
} catch (final Exception e) {
throw new RuntimeException(e);
}
}
protected MapCacheServer createMapCacheServer(int port, int maxSize, SSLContext sslContext, EvictionPolicy evictionPolicy, File persistenceDir) throws IOException {
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir);
protected MapCacheServer createMapCacheServer(
final int port, final int maxSize, final SSLContext sslContext, final EvictionPolicy evictionPolicy,
final File persistenceDir, final int maxReadSize) throws IOException {
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir, maxReadSize);
}
}

View File

@ -38,8 +38,8 @@ public class MapCacheServer extends AbstractCacheServer {
private final MapCache cache;
public MapCacheServer(final String identifier, final SSLContext sslContext, final int port, final int maxSize,
final EvictionPolicy evictionPolicy, final File persistencePath) throws IOException {
super(identifier, sslContext, port);
final EvictionPolicy evictionPolicy, final File persistencePath, final int maxReadSize) throws IOException {
super(identifier, sslContext, port, maxReadSize);
final MapCache simpleCache = new SimpleMapCache(identifier, maxSize, evictionPolicy);
@ -123,7 +123,7 @@ public class MapCacheServer extends AbstractCacheServer {
break;
}
case "subMap": {
final int numKeys = dis.readInt();
final int numKeys = validateSize(dis.readInt());
for(int i=0;i<numKeys;i++) {
final byte[] key = readValue(dis);
final ByteBuffer existingValue = cache.get(ByteBuffer.wrap(key));
@ -249,12 +249,4 @@ public class MapCacheServer extends AbstractCacheServer {
stop();
}
}
private byte[] readValue(final DataInputStream dis) throws IOException {
final int numBytes = dis.readInt();
final byte[] buffer = new byte[numBytes];
dis.readFully(buffer);
return buffer;
}
}

View File

@ -19,6 +19,7 @@ package org.apache.nifi.distributed.cache.server;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -33,6 +34,7 @@ import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.SerializationException;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.distributed.cache.client.AtomicCacheEntry;
import org.apache.nifi.distributed.cache.client.Deserializer;
@ -42,12 +44,14 @@ import org.apache.nifi.distributed.cache.client.Serializer;
import org.apache.nifi.distributed.cache.client.exception.DeserializationException;
import org.apache.nifi.distributed.cache.server.map.DistributedMapCacheServer;
import org.apache.nifi.distributed.cache.server.map.MapCacheServer;
import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.processor.Processor;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.remote.StandardVersionNegotiator;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockConfigurationContext;
import org.apache.nifi.util.MockControllerServiceInitializationContext;
import org.apache.nifi.util.MockPropertyValue;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.Test;
@ -597,8 +601,8 @@ public class TestServerAndClient {
// Create a server that only supports protocol version 1.
final DistributedMapCacheServer server = new MapServer() {
@Override
protected MapCacheServer createMapCacheServer(int port, int maxSize, SSLContext sslContext, EvictionPolicy evictionPolicy, File persistenceDir) throws IOException {
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir) {
protected MapCacheServer createMapCacheServer(int port, int maxSize, SSLContext sslContext, EvictionPolicy evictionPolicy, File persistenceDir, int maxReadSize) throws IOException {
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir, maxReadSize) {
@Override
protected StandardVersionNegotiator getVersionNegotiator() {
return new StandardVersionNegotiator(1);
@ -666,6 +670,45 @@ public class TestServerAndClient {
server.shutdownServer();
}
@Test
public void testLimitServiceReadSizeMap() throws InitializationException, IOException {
final TestRunner runner = TestRunners.newTestRunner(Mockito.mock(Processor.class));
final DistributedMapCacheServer server = new MapServer();
runner.addControllerService("server", server);
runner.enableControllerService(server);
final DistributedMapCacheClientService client = createMapClient(server.getPort());
final Serializer<String> serializer = new StringSerializer();
final String key = "key";
final int maxReadSize = new MockPropertyValue(DistributedCacheServer.MAX_READ_SIZE.getDefaultValue()).asDataSize(DataUnit.B).intValue();
final int belowThreshold = maxReadSize / key.length();
final int aboveThreshold = belowThreshold + 1;
final String keyBelowThreshold = StringUtils.repeat(key, belowThreshold);
final String keyAboveThreshold = StringUtils.repeat(key, aboveThreshold);
assertFalse(client.containsKey(keyBelowThreshold, serializer));
assertThrows(IOException.class, () -> client.containsKey(keyAboveThreshold, serializer));
}
@Test
public void testLimitServiceReadSizeSet() throws InitializationException, IOException {
final TestRunner runner = TestRunners.newTestRunner(Mockito.mock(Processor.class));
final DistributedSetCacheServer server = new SetServer();
runner.addControllerService("server", server);
runner.enableControllerService(server);
final DistributedSetCacheClientService client = createClient(server.getPort());
final Serializer<String> serializer = new StringSerializer();
final String value = "value";
final int maxReadSize = new MockPropertyValue(DistributedCacheServer.MAX_READ_SIZE.getDefaultValue()).asDataSize(DataUnit.B).intValue();
final int belowThreshold = maxReadSize / value.length();
final int aboveThreshold = belowThreshold + 1;
final String valueBelowThreshold = StringUtils.repeat(value, belowThreshold);
final String valueAboveThreshold = StringUtils.repeat(value, aboveThreshold);
assertFalse(client.contains(valueBelowThreshold, serializer));
assertThrows(IOException.class, () -> client.contains(valueAboveThreshold, serializer));
}
private void waitABit() {
try {