mirror of
https://github.com/apache/nifi.git
synced 2025-02-09 11:35:05 +00:00
NIFI-9207 - Added Max Read Size to Distributed Cache Servers
This closes #5420 Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
parent
efc1cb012f
commit
46a5e3f096
@ -45,7 +45,7 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
|
||||
private ChannelPromise channelPromise;
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws IOException {
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||
final ByteBuf byteBuf = (ByteBuf) msg;
|
||||
try {
|
||||
final byte[] bytes = new byte[byteBuf.readableBytes()];
|
||||
@ -57,13 +57,25 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
|
||||
public void channelReadComplete(final ChannelHandlerContext ctx) throws IOException {
|
||||
inboundAdapter.dequeue();
|
||||
if (inboundAdapter.isComplete() && !channelPromise.isSuccess()) {
|
||||
channelPromise.setSuccess();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelUnregistered(final ChannelHandlerContext ctx) {
|
||||
if (!inboundAdapter.isComplete()) {
|
||||
channelPromise.setFailure(new IOException("Channel unregistered before processing completed: " + ctx.channel().toString()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
|
||||
channelPromise.setFailure(cause);
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a synchronous method call to the server. The server is expected to write
|
||||
* a byte stream response to the channel, which may be deserialized into a Java object
|
||||
@ -86,5 +98,8 @@ public class CacheClientRequestHandler extends ChannelInboundHandlerAdapter {
|
||||
channel.writeAndFlush(Unpooled.wrappedBuffer(outboundAdapter.toBytes()));
|
||||
channelPromise.awaitUninterruptibly();
|
||||
this.inboundAdapter = new NullInboundAdapter();
|
||||
if (channelPromise.cause() != null) {
|
||||
throw new IOException("Request invocation failed", channelPromise.cause());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ public class DistributedMapCacheClientService extends AbstractControllerService
|
||||
.build();
|
||||
|
||||
/**
|
||||
* The implementation of the business logic for {@link DistributedSetCacheClientService}.
|
||||
* The implementation of the business logic for {@link DistributedMapCacheClientService}.
|
||||
*/
|
||||
private volatile NettyDistributedMapCacheClient cacheClient = null;
|
||||
|
||||
|
@ -32,7 +32,7 @@ public class NullInboundAdapter implements InboundAdapter {
|
||||
|
||||
@Override
|
||||
public boolean isComplete() {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -18,6 +18,7 @@ package org.apache.nifi.distributed.cache.server;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.DataInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
@ -49,16 +50,18 @@ public abstract class AbstractCacheServer implements CacheServer {
|
||||
|
||||
private final String identifier;
|
||||
private final int port;
|
||||
private final int maxReadSize;
|
||||
private final SSLContext sslContext;
|
||||
protected volatile boolean stopped = false;
|
||||
private final Set<Thread> processInputThreads = new CopyOnWriteArraySet<>();
|
||||
|
||||
private volatile ServerSocketChannel serverSocketChannel;
|
||||
|
||||
public AbstractCacheServer(final String identifier, final SSLContext sslContext, final int port) {
|
||||
public AbstractCacheServer(final String identifier, final SSLContext sslContext, final int port, final int maxReadSize) {
|
||||
this.identifier = identifier;
|
||||
this.port = port;
|
||||
this.sslContext = sslContext;
|
||||
this.maxReadSize = maxReadSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -108,14 +111,14 @@ public abstract class AbstractCacheServer implements CacheServer {
|
||||
rawInputStream = new SSLSocketChannelInputStream(sslSocketChannel);
|
||||
rawOutputStream = new SSLSocketChannelOutputStream(sslSocketChannel);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
logger.error("Cannot create input and/or output streams for {}", new Object[]{identifier}, e);
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.error("", e);
|
||||
}
|
||||
try {
|
||||
socketChannel.close();
|
||||
} catch (IOException swallow) {
|
||||
} catch (final IOException swallow) {
|
||||
}
|
||||
|
||||
return;
|
||||
@ -179,19 +182,19 @@ public abstract class AbstractCacheServer implements CacheServer {
|
||||
if (serverSocketChannel != null && serverSocketChannel.isOpen()) {
|
||||
try {
|
||||
serverSocketChannel.close();
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
logger.warn("Server Socket Close Failed", e);
|
||||
}
|
||||
}
|
||||
// need to close out the created SocketChannels...this is done by interrupting
|
||||
// the created threads that loop on listen().
|
||||
for (Thread processInputThread : processInputThreads) {
|
||||
for (final Thread processInputThread : processInputThreads) {
|
||||
processInputThread.interrupt();
|
||||
int i = 0;
|
||||
while (!processInputThread.isInterrupted() && i++ < 5) {
|
||||
try {
|
||||
Thread.sleep(50); // allow thread to gracefully terminate
|
||||
} catch (InterruptedException e) {
|
||||
} catch (final InterruptedException e) {
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -213,4 +216,32 @@ public abstract class AbstractCacheServer implements CacheServer {
|
||||
* @throws IOException ex
|
||||
*/
|
||||
protected abstract boolean listen(InputStream in, OutputStream out, int version) throws IOException;
|
||||
|
||||
/**
|
||||
* Read a length-prefixed value from the {@link DataInputStream}.
|
||||
*
|
||||
* @param dis the {@link DataInputStream} from which to read the value
|
||||
* @return the serialized representation of the value
|
||||
* @throws IOException on failure to read from the input stream
|
||||
*/
|
||||
protected byte[] readValue(final DataInputStream dis) throws IOException {
|
||||
final int numBytes = validateSize(dis.readInt());
|
||||
final byte[] buffer = new byte[numBytes];
|
||||
dis.readFully(buffer);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a size value received from the {@link DataInputStream} against the configured maximum.
|
||||
*
|
||||
* @param size the size value received from the {@link DataInputStream}
|
||||
* @return the size value, iff it passes validation; otherwise, an exception is thrown
|
||||
*/
|
||||
protected int validateSize(final int size) {
|
||||
if (size <= maxReadSize) {
|
||||
return size;
|
||||
} else {
|
||||
throw new IllegalStateException(String.format("Size [%d] exceeds maximum configured read [%d]", size, maxReadSize));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -68,6 +68,14 @@ public abstract class DistributedCacheServer extends AbstractControllerService {
|
||||
.required(false)
|
||||
.addValidator(StandardValidators.createDirectoryExistsValidator(true, true))
|
||||
.build();
|
||||
public static final PropertyDescriptor MAX_READ_SIZE = new PropertyDescriptor.Builder()
|
||||
.name("maximum-read-size")
|
||||
.displayName("Maximum Read Size")
|
||||
.description("The maximum number of network bytes to read for a single cache item")
|
||||
.required(false)
|
||||
.addValidator(StandardValidators.DATA_SIZE_VALIDATOR)
|
||||
.defaultValue("1 MB")
|
||||
.build();
|
||||
|
||||
private volatile CacheServer cacheServer;
|
||||
|
||||
@ -79,6 +87,7 @@ public abstract class DistributedCacheServer extends AbstractControllerService {
|
||||
properties.add(EVICTION_POLICY);
|
||||
properties.add(PERSISTENCE_PATH);
|
||||
properties.add(SSL_CONTEXT_SERVICE);
|
||||
properties.add(MAX_READ_SIZE);
|
||||
return properties;
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ import javax.net.ssl.SSLContext;
|
||||
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
||||
import org.apache.nifi.annotation.documentation.Tags;
|
||||
import org.apache.nifi.controller.ConfigurationContext;
|
||||
import org.apache.nifi.processor.DataUnit;
|
||||
import org.apache.nifi.ssl.SSLContextService;
|
||||
|
||||
@Tags({"distributed", "set", "distinct", "cache", "server"})
|
||||
@ -35,6 +36,7 @@ public class DistributedSetCacheServer extends DistributedCacheServer {
|
||||
final SSLContextService sslContextService = context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class);
|
||||
final int maxSize = context.getProperty(MAX_CACHE_ENTRIES).asInteger();
|
||||
final String evictionPolicyName = context.getProperty(EVICTION_POLICY).getValue();
|
||||
final int maxReadSize = context.getProperty(MAX_READ_SIZE).asDataSize(DataUnit.B).intValue();
|
||||
|
||||
final SSLContext sslContext;
|
||||
if (sslContextService == null) {
|
||||
@ -61,7 +63,7 @@ public class DistributedSetCacheServer extends DistributedCacheServer {
|
||||
try {
|
||||
final File persistenceDir = persistencePath == null ? null : new File(persistencePath);
|
||||
|
||||
return new SetCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir);
|
||||
return new SetCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir, maxReadSize);
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
@ -36,8 +36,8 @@ public class SetCacheServer extends AbstractCacheServer {
|
||||
private final SetCache cache;
|
||||
|
||||
public SetCacheServer(final String identifier, final SSLContext sslContext, final int port, final int maxSize,
|
||||
final EvictionPolicy evictionPolicy, final File persistencePath) throws IOException {
|
||||
super(identifier, sslContext, port);
|
||||
final EvictionPolicy evictionPolicy, final File persistencePath, final int maxReadSize) throws IOException {
|
||||
super(identifier, sslContext, port, maxReadSize);
|
||||
|
||||
final SetCache simpleCache = new SimpleSetCache(identifier, maxSize, evictionPolicy);
|
||||
|
||||
@ -60,9 +60,7 @@ public class SetCacheServer extends AbstractCacheServer {
|
||||
return false;
|
||||
}
|
||||
|
||||
final int valueLength = dis.readInt();
|
||||
final byte[] value = new byte[valueLength];
|
||||
dis.readFully(value);
|
||||
final byte[] value = readValue(dis);
|
||||
final ByteBuffer valueBuffer = ByteBuffer.wrap(value);
|
||||
|
||||
final SetCacheResult response;
|
||||
@ -101,5 +99,4 @@ public class SetCacheServer extends AbstractCacheServer {
|
||||
stop();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ import org.apache.nifi.controller.ConfigurationContext;
|
||||
import org.apache.nifi.distributed.cache.server.CacheServer;
|
||||
import org.apache.nifi.distributed.cache.server.DistributedCacheServer;
|
||||
import org.apache.nifi.distributed.cache.server.EvictionPolicy;
|
||||
import org.apache.nifi.processor.DataUnit;
|
||||
import org.apache.nifi.ssl.SSLContextService;
|
||||
|
||||
@Tags({"distributed", "cluster", "map", "cache", "server", "key/value"})
|
||||
@ -41,6 +42,7 @@ public class DistributedMapCacheServer extends DistributedCacheServer {
|
||||
final SSLContextService sslContextService = context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class);
|
||||
final int maxSize = context.getProperty(MAX_CACHE_ENTRIES).asInteger();
|
||||
final String evictionPolicyName = context.getProperty(EVICTION_POLICY).getValue();
|
||||
final int maxReadSize = context.getProperty(MAX_READ_SIZE).asDataSize(DataUnit.B).intValue();
|
||||
|
||||
final SSLContext sslContext;
|
||||
if (sslContextService == null) {
|
||||
@ -67,14 +69,16 @@ public class DistributedMapCacheServer extends DistributedCacheServer {
|
||||
try {
|
||||
final File persistenceDir = persistencePath == null ? null : new File(persistencePath);
|
||||
|
||||
return createMapCacheServer(port, maxSize, sslContext, evictionPolicy, persistenceDir);
|
||||
return createMapCacheServer(port, maxSize, sslContext, evictionPolicy, persistenceDir, maxReadSize);
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
protected MapCacheServer createMapCacheServer(int port, int maxSize, SSLContext sslContext, EvictionPolicy evictionPolicy, File persistenceDir) throws IOException {
|
||||
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir);
|
||||
protected MapCacheServer createMapCacheServer(
|
||||
final int port, final int maxSize, final SSLContext sslContext, final EvictionPolicy evictionPolicy,
|
||||
final File persistenceDir, final int maxReadSize) throws IOException {
|
||||
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir, maxReadSize);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -38,8 +38,8 @@ public class MapCacheServer extends AbstractCacheServer {
|
||||
private final MapCache cache;
|
||||
|
||||
public MapCacheServer(final String identifier, final SSLContext sslContext, final int port, final int maxSize,
|
||||
final EvictionPolicy evictionPolicy, final File persistencePath) throws IOException {
|
||||
super(identifier, sslContext, port);
|
||||
final EvictionPolicy evictionPolicy, final File persistencePath, final int maxReadSize) throws IOException {
|
||||
super(identifier, sslContext, port, maxReadSize);
|
||||
|
||||
final MapCache simpleCache = new SimpleMapCache(identifier, maxSize, evictionPolicy);
|
||||
|
||||
@ -123,7 +123,7 @@ public class MapCacheServer extends AbstractCacheServer {
|
||||
break;
|
||||
}
|
||||
case "subMap": {
|
||||
final int numKeys = dis.readInt();
|
||||
final int numKeys = validateSize(dis.readInt());
|
||||
for(int i=0;i<numKeys;i++) {
|
||||
final byte[] key = readValue(dis);
|
||||
final ByteBuffer existingValue = cache.get(ByteBuffer.wrap(key));
|
||||
@ -249,12 +249,4 @@ public class MapCacheServer extends AbstractCacheServer {
|
||||
stop();
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] readValue(final DataInputStream dis) throws IOException {
|
||||
final int numBytes = dis.readInt();
|
||||
final byte[] buffer = new byte[numBytes];
|
||||
dis.readFully(buffer);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ package org.apache.nifi.distributed.cache.server;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
@ -33,6 +34,7 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.commons.lang3.SerializationException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.nifi.components.PropertyDescriptor;
|
||||
import org.apache.nifi.distributed.cache.client.AtomicCacheEntry;
|
||||
import org.apache.nifi.distributed.cache.client.Deserializer;
|
||||
@ -42,12 +44,14 @@ import org.apache.nifi.distributed.cache.client.Serializer;
|
||||
import org.apache.nifi.distributed.cache.client.exception.DeserializationException;
|
||||
import org.apache.nifi.distributed.cache.server.map.DistributedMapCacheServer;
|
||||
import org.apache.nifi.distributed.cache.server.map.MapCacheServer;
|
||||
import org.apache.nifi.processor.DataUnit;
|
||||
import org.apache.nifi.processor.Processor;
|
||||
import org.apache.nifi.processor.util.StandardValidators;
|
||||
import org.apache.nifi.remote.StandardVersionNegotiator;
|
||||
import org.apache.nifi.reporting.InitializationException;
|
||||
import org.apache.nifi.util.MockConfigurationContext;
|
||||
import org.apache.nifi.util.MockControllerServiceInitializationContext;
|
||||
import org.apache.nifi.util.MockPropertyValue;
|
||||
import org.apache.nifi.util.TestRunner;
|
||||
import org.apache.nifi.util.TestRunners;
|
||||
import org.junit.Test;
|
||||
@ -597,8 +601,8 @@ public class TestServerAndClient {
|
||||
// Create a server that only supports protocol version 1.
|
||||
final DistributedMapCacheServer server = new MapServer() {
|
||||
@Override
|
||||
protected MapCacheServer createMapCacheServer(int port, int maxSize, SSLContext sslContext, EvictionPolicy evictionPolicy, File persistenceDir) throws IOException {
|
||||
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir) {
|
||||
protected MapCacheServer createMapCacheServer(int port, int maxSize, SSLContext sslContext, EvictionPolicy evictionPolicy, File persistenceDir, int maxReadSize) throws IOException {
|
||||
return new MapCacheServer(getIdentifier(), sslContext, port, maxSize, evictionPolicy, persistenceDir, maxReadSize) {
|
||||
@Override
|
||||
protected StandardVersionNegotiator getVersionNegotiator() {
|
||||
return new StandardVersionNegotiator(1);
|
||||
@ -666,6 +670,45 @@ public class TestServerAndClient {
|
||||
server.shutdownServer();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLimitServiceReadSizeMap() throws InitializationException, IOException {
|
||||
final TestRunner runner = TestRunners.newTestRunner(Mockito.mock(Processor.class));
|
||||
final DistributedMapCacheServer server = new MapServer();
|
||||
runner.addControllerService("server", server);
|
||||
runner.enableControllerService(server);
|
||||
|
||||
final DistributedMapCacheClientService client = createMapClient(server.getPort());
|
||||
final Serializer<String> serializer = new StringSerializer();
|
||||
|
||||
final String key = "key";
|
||||
final int maxReadSize = new MockPropertyValue(DistributedCacheServer.MAX_READ_SIZE.getDefaultValue()).asDataSize(DataUnit.B).intValue();
|
||||
final int belowThreshold = maxReadSize / key.length();
|
||||
final int aboveThreshold = belowThreshold + 1;
|
||||
final String keyBelowThreshold = StringUtils.repeat(key, belowThreshold);
|
||||
final String keyAboveThreshold = StringUtils.repeat(key, aboveThreshold);
|
||||
assertFalse(client.containsKey(keyBelowThreshold, serializer));
|
||||
assertThrows(IOException.class, () -> client.containsKey(keyAboveThreshold, serializer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLimitServiceReadSizeSet() throws InitializationException, IOException {
|
||||
final TestRunner runner = TestRunners.newTestRunner(Mockito.mock(Processor.class));
|
||||
final DistributedSetCacheServer server = new SetServer();
|
||||
runner.addControllerService("server", server);
|
||||
runner.enableControllerService(server);
|
||||
|
||||
final DistributedSetCacheClientService client = createClient(server.getPort());
|
||||
final Serializer<String> serializer = new StringSerializer();
|
||||
|
||||
final String value = "value";
|
||||
final int maxReadSize = new MockPropertyValue(DistributedCacheServer.MAX_READ_SIZE.getDefaultValue()).asDataSize(DataUnit.B).intValue();
|
||||
final int belowThreshold = maxReadSize / value.length();
|
||||
final int aboveThreshold = belowThreshold + 1;
|
||||
final String valueBelowThreshold = StringUtils.repeat(value, belowThreshold);
|
||||
final String valueAboveThreshold = StringUtils.repeat(value, aboveThreshold);
|
||||
assertFalse(client.contains(valueBelowThreshold, serializer));
|
||||
assertThrows(IOException.class, () -> client.contains(valueAboveThreshold, serializer));
|
||||
}
|
||||
|
||||
private void waitABit() {
|
||||
try {
|
||||
|
Loading…
x
Reference in New Issue
Block a user