diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/AbstractPutEventProcessor.java b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/AbstractPutEventProcessor.java index 5833819ea8..a246272c90 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/AbstractPutEventProcessor.java +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/AbstractPutEventProcessor.java @@ -360,10 +360,12 @@ public abstract class AbstractPutEventProcessor extends AbstractSessionFactoryPr boolean returned = senderPool.offer(sender); // if the pool is full then close the sender. if (!returned) { + getLogger().debug("Sender wasn't returned because queue was full, closing sender"); sender.close(); } } else { // probably already closed here, but quietly close anyway to be safe. + getLogger().debug("Sender is not connected, closing sender"); sender.close(); } } diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SocketChannelSender.java b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SocketChannelSender.java index 6f7796b21a..a8ad2e4dce 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SocketChannelSender.java +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-processor-utils/src/main/java/org/apache/nifi/processor/util/put/sender/SocketChannelSender.java @@ -98,7 +98,37 @@ public class SocketChannelSender extends ChannelSender { } public OutputStream getOutputStream() { - return socketChannelOutput; + return new OutputStream() { + @Override + public void write(int b) throws IOException { + socketChannelOutput.write(b); + } + + @Override + public void write(byte[] b) throws IOException { + socketChannelOutput.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + socketChannelOutput.write(b, off, len); + } + + @Override + public void close() throws IOException { + socketChannelOutput.close(); + } + + @Override + public void flush() throws IOException { + socketChannelOutput.flush(); + updateLastUsed(); + } + }; + } + + private void updateLastUsed() { + this.lastUsed = System.currentTimeMillis(); } } diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/pom.xml b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/pom.xml index a315382394..3ed4ac09c9 100644 --- a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/pom.xml +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/pom.xml @@ -32,11 +32,19 @@ org.apache.nifi nifi-utils + + org.apache.nifi + nifi-security-utils + org.apache.nifi nifi-schema-registry-service-api + + org.apache.nifi + nifi-record-serialization-service-api + org.apache.nifi nifi-record diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/IOUtils.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/IOUtils.java new file mode 100644 index 0000000000..43bbd18112 --- /dev/null +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/IOUtils.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.record.listen; + +import java.io.Closeable; +import java.io.IOException; + +public class IOUtils { + + public static void closeQuietly(final Closeable closeable) { + if (closeable == null) { + return; + } + try { + + closeable.close(); + } catch (final IOException ioe) { + + } + } + +} diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SSLSocketChannelRecordReader.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SSLSocketChannelRecordReader.java new file mode 100644 index 0000000000..873297e497 --- /dev/null +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SSLSocketChannelRecordReader.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.record.listen; + +import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannel; +import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannelInputStream; +import org.apache.nifi.schema.access.SchemaNotFoundException; +import org.apache.nifi.serialization.MalformedRecordException; +import org.apache.nifi.serialization.RecordReader; +import org.apache.nifi.serialization.RecordReaderFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.nio.channels.SocketChannel; + +/** + * Encapsulates an SSLSocketChannel and a RecordReader created for the given channel. + */ +public class SSLSocketChannelRecordReader implements SocketChannelRecordReader { + + private final SocketChannel socketChannel; + private final SSLSocketChannel sslSocketChannel; + private final RecordReaderFactory readerFactory; + private final SocketChannelRecordReaderDispatcher dispatcher; + + private RecordReader recordReader; + + public SSLSocketChannelRecordReader(final SocketChannel socketChannel, + final SSLSocketChannel sslSocketChannel, + final RecordReaderFactory readerFactory, + final SocketChannelRecordReaderDispatcher dispatcher) { + this.socketChannel = socketChannel; + this.sslSocketChannel = sslSocketChannel; + this.readerFactory = readerFactory; + this.dispatcher = dispatcher; + } + + @Override + public RecordReader createRecordReader(final FlowFile flowFile, final ComponentLog logger) throws IOException, MalformedRecordException, SchemaNotFoundException { + if (recordReader != null) { + throw new IllegalStateException("Cannot create RecordReader because already created"); + } + + final InputStream in = new SSLSocketChannelInputStream(sslSocketChannel); + recordReader = readerFactory.createRecordReader(flowFile, in, logger); + return recordReader; + } + + @Override + public RecordReader getRecordReader() { + return recordReader; + } + + @Override + public InetAddress getRemoteAddress() { + return socketChannel.socket().getInetAddress(); + } + + @Override + public boolean isClosed() { + return sslSocketChannel.isClosed(); + } + + @Override + public void close() { + IOUtils.closeQuietly(recordReader); + IOUtils.closeQuietly(sslSocketChannel); + dispatcher.connectionCompleted(); + } + +} diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SocketChannelRecordReader.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SocketChannelRecordReader.java new file mode 100644 index 0000000000..b648b7753d --- /dev/null +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SocketChannelRecordReader.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.record.listen; + +import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.schema.access.SchemaNotFoundException; +import org.apache.nifi.serialization.MalformedRecordException; +import org.apache.nifi.serialization.RecordReader; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; + +/** + * Encapsulates a SocketChannel and a RecordReader for the channel. + */ +public interface SocketChannelRecordReader extends Closeable { + + /** + * Currently a RecordReader can only be created with a FlowFile. Since we won't have a FlowFile at the time + * a connection is accepted, this method will be used to lazily create the RecordReader later. Eventually this + * method should be removed and the reader should be passed in through the constructor. + * + * + * @param flowFile the flow file we are creating the reader for + * @param logger the logger of the component creating the reader + * @return a RecordReader + * + * @throws IllegalStateException if create is called after a reader has already been created + */ + RecordReader createRecordReader(final FlowFile flowFile, final ComponentLog logger) throws IOException, MalformedRecordException, SchemaNotFoundException; + + /** + * @return the RecordReader created by calling createRecordReader, or null if one has not been created yet + */ + RecordReader getRecordReader(); + + /** + * @return the remote address of the underlying channel + */ + InetAddress getRemoteAddress(); + + /** + * @return true if the underlying channel is closed, false otherwise + */ + boolean isClosed(); + +} diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SocketChannelRecordReaderDispatcher.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SocketChannelRecordReaderDispatcher.java new file mode 100644 index 0000000000..a72b0d8df1 --- /dev/null +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/SocketChannelRecordReaderDispatcher.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.record.listen; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannel; +import org.apache.nifi.security.util.SslContextFactory; +import org.apache.nifi.serialization.RecordReaderFactory; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import java.io.Closeable; +import java.net.SocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Accepts connections on the given ServerSocketChannel and dispatches a SocketChannelRecordReader for processing. + */ +public class SocketChannelRecordReaderDispatcher implements Runnable, Closeable { + + private final ServerSocketChannel serverSocketChannel; + private final SSLContext sslContext; + private final SslContextFactory.ClientAuth clientAuth; + private final int socketReadTimeout; + private final int receiveBufferSize; + private final int maxConnections; + private final RecordReaderFactory readerFactory; + private final BlockingQueue recordReaders; + private final ComponentLog logger; + + private final AtomicInteger currentConnections = new AtomicInteger(0); + + private volatile boolean stopped = false; + + public SocketChannelRecordReaderDispatcher(final ServerSocketChannel serverSocketChannel, + final SSLContext sslContext, + final SslContextFactory.ClientAuth clientAuth, + final int socketReadTimeout, + final int receiveBufferSize, + final int maxConnections, + final RecordReaderFactory readerFactory, + final BlockingQueue recordReaders, + final ComponentLog logger) { + this.serverSocketChannel = serverSocketChannel; + this.sslContext = sslContext; + this.clientAuth = clientAuth; + this.socketReadTimeout = socketReadTimeout; + this.receiveBufferSize = receiveBufferSize; + this.maxConnections = maxConnections; + this.readerFactory = readerFactory; + this.recordReaders = recordReaders; + this.logger = logger; + } + + @Override + public void run() { + while(!stopped) { + try { + final SocketChannel socketChannel = serverSocketChannel.accept(); + if (socketChannel == null) { + Thread.sleep(20); + continue; + } + + final SocketAddress remoteSocketAddress = socketChannel.getRemoteAddress(); + socketChannel.socket().setSoTimeout(socketReadTimeout); + socketChannel.socket().setReceiveBufferSize(receiveBufferSize); + + if (currentConnections.incrementAndGet() > maxConnections){ + currentConnections.decrementAndGet(); + final String remoteAddress = remoteSocketAddress == null ? "null" : remoteSocketAddress.toString(); + logger.warn("Rejecting connection from {} because max connections has been met", new Object[]{remoteAddress}); + IOUtils.closeQuietly(socketChannel); + continue; + } + + if (logger.isDebugEnabled()) { + final String remoteAddress = remoteSocketAddress == null ? "null" : remoteSocketAddress.toString(); + logger.debug("Accepted connection from {}", new Object[]{remoteAddress}); + } + + // create a StandardSocketChannelRecordReader or an SSLSocketChannelRecordReader based on presence of SSLContext + final SocketChannelRecordReader socketChannelRecordReader; + if (sslContext == null) { + socketChannelRecordReader = new StandardSocketChannelRecordReader(socketChannel, readerFactory, this); + } else { + final SSLEngine sslEngine = sslContext.createSSLEngine(); + sslEngine.setUseClientMode(false); + + switch (clientAuth) { + case REQUIRED: + sslEngine.setNeedClientAuth(true); + break; + case WANT: + sslEngine.setWantClientAuth(true); + break; + case NONE: + sslEngine.setNeedClientAuth(false); + sslEngine.setWantClientAuth(false); + break; + } + + final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslEngine, socketChannel); + socketChannelRecordReader = new SSLSocketChannelRecordReader(socketChannel, sslSocketChannel, readerFactory, this); + } + + // queue the SocketChannelRecordReader for processing by the processor + recordReaders.offer(socketChannelRecordReader); + + } catch (Exception e) { + logger.error("Error dispatching connection: " + e.getMessage(), e); + } + } + } + + public int getPort() { + return serverSocketChannel == null ? 0 : serverSocketChannel.socket().getLocalPort(); + } + + @Override + public void close() { + this.stopped = true; + IOUtils.closeQuietly(this.serverSocketChannel); + } + + public void connectionCompleted() { + currentConnections.decrementAndGet(); + } + +} diff --git a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/StandardSocketChannelRecordReader.java b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/StandardSocketChannelRecordReader.java new file mode 100644 index 0000000000..1e220442c0 --- /dev/null +++ b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/main/java/org/apache/nifi/record/listen/StandardSocketChannelRecordReader.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.record.listen; + +import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.schema.access.SchemaNotFoundException; +import org.apache.nifi.serialization.MalformedRecordException; +import org.apache.nifi.serialization.RecordReader; +import org.apache.nifi.serialization.RecordReaderFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.nio.channels.SocketChannel; + +/** + * Encapsulates a SocketChannel and a RecordReader created for the given channel. + */ +public class StandardSocketChannelRecordReader implements SocketChannelRecordReader { + + private final SocketChannel socketChannel; + private final RecordReaderFactory readerFactory; + private final SocketChannelRecordReaderDispatcher dispatcher; + + private RecordReader recordReader; + + public StandardSocketChannelRecordReader(final SocketChannel socketChannel, + final RecordReaderFactory readerFactory, + final SocketChannelRecordReaderDispatcher dispatcher) { + this.socketChannel = socketChannel; + this.readerFactory = readerFactory; + this.dispatcher = dispatcher; + } + + @Override + public RecordReader createRecordReader(final FlowFile flowFile, final ComponentLog logger) throws IOException, MalformedRecordException, SchemaNotFoundException { + if (recordReader != null) { + throw new IllegalStateException("Cannot create RecordReader because already created"); + } + + final InputStream in = socketChannel.socket().getInputStream(); + recordReader = readerFactory.createRecordReader(flowFile, in, logger); + return recordReader; + } + + @Override + public RecordReader getRecordReader() { + return recordReader; + } + + @Override + public InetAddress getRemoteAddress() { + return socketChannel.socket().getInetAddress(); + } + + @Override + public boolean isClosed() { + return !socketChannel.isOpen(); + } + + @Override + public void close() { + IOUtils.closeQuietly(recordReader); + IOUtils.closeQuietly(socketChannel); + dispatcher.connectionCompleted(); + } +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenTCPRecord.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenTCPRecord.java new file mode 100644 index 0000000000..2ad9ab57b2 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenTCPRecord.java @@ -0,0 +1,463 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard; + +import org.apache.commons.io.IOUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.nifi.annotation.behavior.InputRequirement; +import org.apache.nifi.annotation.behavior.SupportsBatching; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import org.apache.nifi.annotation.behavior.WritesAttributes; +import org.apache.nifi.annotation.documentation.CapabilityDescription; +import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.annotation.lifecycle.OnScheduled; +import org.apache.nifi.annotation.lifecycle.OnStopped; +import org.apache.nifi.components.AllowableValue; +import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.components.ValidationContext; +import org.apache.nifi.components.ValidationResult; +import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.flowfile.attributes.CoreAttributes; +import org.apache.nifi.processor.AbstractProcessor; +import org.apache.nifi.processor.DataUnit; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processor.ProcessSession; +import org.apache.nifi.processor.Relationship; +import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.processor.util.StandardValidators; +import org.apache.nifi.processor.util.listen.ListenerProperties; +import org.apache.nifi.record.listen.SocketChannelRecordReader; +import org.apache.nifi.record.listen.SocketChannelRecordReaderDispatcher; +import org.apache.nifi.security.util.SslContextFactory; +import org.apache.nifi.serialization.RecordReader; +import org.apache.nifi.serialization.RecordReaderFactory; +import org.apache.nifi.serialization.RecordSetWriter; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.WriteResult; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.ssl.SSLContextService; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.SocketTimeoutException; +import java.nio.channels.ServerSocketChannel; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.apache.nifi.processor.util.listen.ListenerProperties.NETWORK_INTF_NAME; + +@SupportsBatching +@InputRequirement(InputRequirement.Requirement.INPUT_FORBIDDEN) +@Tags({"listen", "tcp", "record", "tls", "ssl"}) +@CapabilityDescription("Listens for incoming TCP connections and reads data from each connection using a configured record " + + "reader, and writes the records to a flow file using a configured record writer. The type of record reader selected will " + + "determine how clients are expected to send data. For example, when using a Grok reader to read logs, a client can keep an " + + "open connection and continuously stream data, but when using an JSON reader, the client cannot send an array of JSON " + + "documents and then send another array on the same connection, as the reader would be in a bad state at that point. Records " + + "will be read from the connection in blocking mode, and will timeout according to the Read Timeout specified in the processor. " + + "If the read times out, or if any other error is encountered when reading, the connection will be closed, and any records " + + "read up to that point will be handled according to the configured Read Error Strategy (Discard or Transfer). In cases where " + + "clients are keeping a connection open, the concurrent tasks for the processor should be adjusted to match the Max Number of " + + "TCP Connections allowed, so that there is a task processing each connection.") +@WritesAttributes({ + @WritesAttribute(attribute="tcp.sender", description="The host that sent the data."), + @WritesAttribute(attribute="tcp.port", description="The port that the processor accepted the connection on."), + @WritesAttribute(attribute="record.count", description="The number of records written to the flow file."), + @WritesAttribute(attribute="mime.type", description="The mime-type of the writer used to write the records to the flow file.") +}) +public class ListenTCPRecord extends AbstractProcessor { + + static final PropertyDescriptor PORT = new PropertyDescriptor.Builder() + .name("port") + .displayName("Port") + .description("The port to listen on for communication.") + .required(true) + .addValidator(StandardValidators.PORT_VALIDATOR) + .expressionLanguageSupported(true) + .build(); + + static final PropertyDescriptor READ_TIMEOUT = new PropertyDescriptor.Builder() + .name("read-timeout") + .displayName("Read Timeout") + .description("The amount of time to wait before timing out when reading from a connection.") + .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) + .defaultValue("10 seconds") + .required(true) + .build(); + + static final PropertyDescriptor MAX_SOCKET_BUFFER_SIZE = new PropertyDescriptor.Builder() + .name("max-size-socket-buffer") + .displayName("Max Size of Socket Buffer") + .description("The maximum size of the socket buffer that should be used. This is a suggestion to the Operating System " + + "to indicate how big the socket buffer should be. If this value is set too low, the buffer may fill up before " + + "the data can be read, and incoming data will be dropped.") + .addValidator(StandardValidators.DATA_SIZE_VALIDATOR) + .defaultValue("1 MB") + .required(true) + .build(); + + static final PropertyDescriptor MAX_CONNECTIONS = new PropertyDescriptor.Builder() + .name("max-number-tcp-connections") + .displayName("Max Number of TCP Connections") + .description("The maximum number of concurrent TCP connections to accept. In cases where clients are keeping a connection open, " + + "the concurrent tasks for the processor should be adjusted to match the Max Number of TCP Connections allowed, so that there " + + "is a task processing each connection.") + .addValidator(StandardValidators.createLongValidator(1, 65535, true)) + .defaultValue("2") + .required(true) + .build(); + + static final PropertyDescriptor RECORD_READER = new PropertyDescriptor.Builder() + .name("record-reader") + .displayName("Record Reader") + .description("The Record Reader to use for incoming FlowFiles") + .identifiesControllerService(RecordReaderFactory.class) + .expressionLanguageSupported(false) + .required(true) + .build(); + + static final PropertyDescriptor RECORD_WRITER = new PropertyDescriptor.Builder() + .name("record-writer") + .displayName("Record Writer") + .description("The Record Writer to use in order to serialize the data before writing to a FlowFile") + .identifiesControllerService(RecordSetWriterFactory.class) + .expressionLanguageSupported(false) + .required(true) + .build(); + + static final AllowableValue ERROR_HANDLING_DISCARD = new AllowableValue("Discard", "Discard", "Discards any records already received and closes the connection."); + static final AllowableValue ERROR_HANDLING_TRANSFER = new AllowableValue("Transfer", "Transfer", "Transfers any records already received and closes the connection."); + + static final PropertyDescriptor READER_ERROR_HANDLING_STRATEGY = new PropertyDescriptor.Builder() + .name("reader-error-handling-strategy") + .displayName("Read Error Strategy") + .description("Indicates how to deal with an error while reading the next record from a connection, when previous records have already been read from the connection.") + .required(true) + .allowableValues(ERROR_HANDLING_TRANSFER, ERROR_HANDLING_DISCARD) + .defaultValue(ERROR_HANDLING_TRANSFER.getValue()) + .build(); + + static final PropertyDescriptor RECORD_BATCH_SIZE = new PropertyDescriptor.Builder() + .name("record-batch-size") + .displayName("Record Batch Size") + .description("The maximum number of records to write to a single FlowFile.") + .addValidator(StandardValidators.POSITIVE_INTEGER_VALIDATOR) + .expressionLanguageSupported(false) + .defaultValue("1000") + .required(true) + .build(); + + static final PropertyDescriptor SSL_CONTEXT_SERVICE = new PropertyDescriptor.Builder() + .name("ssl-context-service") + .displayName("SSL Context Service") + .description("The Controller Service to use in order to obtain an SSL Context. If this property is set, " + + "messages will be received over a secure connection.") + .required(false) + .identifiesControllerService(SSLContextService.class) + .build(); + + static final PropertyDescriptor CLIENT_AUTH = new PropertyDescriptor.Builder() + .name("client-auth") + .displayName("Client Auth") + .description("The client authentication policy to use for the SSL Context. Only used if an SSL Context Service is provided.") + .required(false) + .allowableValues(SSLContextService.ClientAuth.values()) + .defaultValue(SSLContextService.ClientAuth.REQUIRED.name()) + .build(); + + static final Relationship REL_SUCCESS = new Relationship.Builder() + .name("success") + .description("Messages received successfully will be sent out this relationship.") + .build(); + + + static final List PROPERTIES; + static { + final List props = new ArrayList<>(); + props.add(ListenerProperties.NETWORK_INTF_NAME); + props.add(PORT); + props.add(MAX_SOCKET_BUFFER_SIZE); + props.add(MAX_CONNECTIONS); + props.add(READ_TIMEOUT); + props.add(RECORD_READER); + props.add(RECORD_WRITER); + props.add(READER_ERROR_HANDLING_STRATEGY); + props.add(RECORD_BATCH_SIZE); + props.add(SSL_CONTEXT_SERVICE); + props.add(CLIENT_AUTH); + PROPERTIES = Collections.unmodifiableList(props); + } + + static final Set RELATIONSHIPS; + static { + final Set rels = new HashSet<>(); + rels.add(REL_SUCCESS); + RELATIONSHIPS = Collections.unmodifiableSet(rels); + } + + static final int POLL_TIMEOUT_MS = 20; + + private volatile int port; + private volatile SocketChannelRecordReaderDispatcher dispatcher; + private volatile BlockingQueue socketReaders = new LinkedBlockingQueue<>(); + + @Override + public Set getRelationships() { + return RELATIONSHIPS; + } + + @Override + protected List getSupportedPropertyDescriptors() { + return PROPERTIES; + } + + @Override + protected Collection customValidate(final ValidationContext validationContext) { + final List results = new ArrayList<>(); + + final String clientAuth = validationContext.getProperty(CLIENT_AUTH).getValue(); + final SSLContextService sslContextService = validationContext.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class); + + if (sslContextService != null && StringUtils.isBlank(clientAuth)) { + results.add(new ValidationResult.Builder() + .explanation("Client Auth must be provided when using TLS/SSL") + .valid(false).subject("Client Auth").build()); + } + + return results; + } + + @OnScheduled + public void onScheduled(final ProcessContext context) throws IOException { + this.port = context.getProperty(PORT).evaluateAttributeExpressions().asInteger(); + + final int readTimeout = context.getProperty(READ_TIMEOUT).asTimePeriod(TimeUnit.MILLISECONDS).intValue(); + final int maxSocketBufferSize = context.getProperty(MAX_SOCKET_BUFFER_SIZE).asDataSize(DataUnit.B).intValue(); + final int maxConnections = context.getProperty(MAX_CONNECTIONS).asInteger(); + final RecordReaderFactory recordReaderFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); + + // if the Network Interface Property wasn't provided then a null InetAddress will indicate to bind to all interfaces + final InetAddress nicAddress; + final String nicAddressStr = context.getProperty(NETWORK_INTF_NAME).evaluateAttributeExpressions().getValue(); + if (!StringUtils.isEmpty(nicAddressStr)) { + NetworkInterface netIF = NetworkInterface.getByName(nicAddressStr); + nicAddress = netIF.getInetAddresses().nextElement(); + } else { + nicAddress = null; + } + + SSLContext sslContext = null; + SslContextFactory.ClientAuth clientAuth = null; + final SSLContextService sslContextService = context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class); + if (sslContextService != null) { + final String clientAuthValue = context.getProperty(CLIENT_AUTH).getValue(); + sslContext = sslContextService.createSSLContext(SSLContextService.ClientAuth.valueOf(clientAuthValue)); + clientAuth = SslContextFactory.ClientAuth.valueOf(clientAuthValue); + } + + // create a ServerSocketChannel in non-blocking mode and bind to the given address and port + final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.configureBlocking(false); + serverSocketChannel.bind(new InetSocketAddress(nicAddress, port)); + + this.dispatcher = new SocketChannelRecordReaderDispatcher(serverSocketChannel, sslContext, clientAuth, readTimeout, + maxSocketBufferSize, maxConnections, recordReaderFactory, socketReaders, getLogger()); + + // start a thread to run the dispatcher + final Thread readerThread = new Thread(dispatcher); + readerThread.setName(getClass().getName() + " [" + getIdentifier() + "]"); + readerThread.setDaemon(true); + readerThread.start(); + } + + @OnStopped + public void onStopped() { + if (dispatcher != null) { + dispatcher.close(); + dispatcher = null; + } + + SocketChannelRecordReader socketRecordReader; + while ((socketRecordReader = socketReaders.poll()) != null) { + IOUtils.closeQuietly(socketRecordReader.getRecordReader()); + } + } + + @Override + public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { + final SocketChannelRecordReader socketRecordReader = pollForSocketRecordReader(); + if (socketRecordReader == null) { + return; + } + + if (socketRecordReader.isClosed()) { + getLogger().warn("Unable to read records from {}, socket already closed", new Object[] {getRemoteAddress(socketRecordReader)}); + IOUtils.closeQuietly(socketRecordReader); // still need to call close so the overall count is decremented + return; + } + + final int recordBatchSize = context.getProperty(RECORD_BATCH_SIZE).asInteger(); + final String readerErrorHandling = context.getProperty(READER_ERROR_HANDLING_STRATEGY).getValue(); + final RecordSetWriterFactory recordSetWriterFactory = context.getProperty(RECORD_WRITER).asControllerService(RecordSetWriterFactory.class); + + // synchronize to ensure there are no stale values in the underlying SocketChannel + synchronized (socketRecordReader) { + FlowFile flowFile = session.create(); + try { + // lazily creating the record reader here b/c we need a flow file, eventually shouldn't have to do this + RecordReader recordReader = socketRecordReader.getRecordReader(); + if (recordReader == null) { + recordReader = socketRecordReader.createRecordReader(flowFile, getLogger()); + } + + Record record; + try { + record = recordReader.nextRecord(); + } catch (final Exception e) { + boolean timeout = false; + + // some of the underlying record libraries wrap the real exception in RuntimeException, so check each + // throwable (starting with the current one) to see if its a SocketTimeoutException + Throwable cause = e; + while (cause != null) { + if (cause instanceof SocketTimeoutException) { + timeout = true; + break; + } + cause = cause.getCause(); + } + + if (timeout) { + getLogger().debug("Timeout reading records, will try again later", e); + socketReaders.offer(socketRecordReader); + session.remove(flowFile); + return; + } else { + throw e; + } + } + + if (record == null) { + getLogger().debug("No records available from {}, closing connection", new Object[]{getRemoteAddress(socketRecordReader)}); + IOUtils.closeQuietly(socketRecordReader); + session.remove(flowFile); + return; + } + + String mimeType = null; + WriteResult writeResult = null; + + final RecordSchema recordSchema = recordSetWriterFactory.getSchema(flowFile, record.getSchema()); + try (final OutputStream out = session.write(flowFile); + final RecordSetWriter recordWriter = recordSetWriterFactory.createWriter(getLogger(), recordSchema, flowFile, out)) { + + // start the record set and write the first record from above + recordWriter.beginRecordSet(); + writeResult = recordWriter.write(record); + + while (record != null && writeResult.getRecordCount() < recordBatchSize) { + // handle a read failure according to the strategy selected... + // if discarding then bounce to the outer catch block which will close the connection and remove the flow file + // if keeping then null out the record to break out of the loop, which will transfer what we have and close the connection + try { + record = recordReader.nextRecord(); + } catch (final SocketTimeoutException ste) { + getLogger().debug("Timeout reading records, will try again later", ste); + break; + } catch (final Exception e) { + if (ERROR_HANDLING_DISCARD.getValue().equals(readerErrorHandling)) { + throw e; + } else { + record = null; + } + } + + if (record != null) { + writeResult = recordWriter.write(record); + } + } + + writeResult = recordWriter.finishRecordSet(); + recordWriter.flush(); + mimeType = recordWriter.getMimeType(); + } + + // if we didn't write any records then we need to remove the flow file + if (writeResult.getRecordCount() <= 0) { + getLogger().debug("Removing flow file, no records were written"); + session.remove(flowFile); + } else { + final String sender = getRemoteAddress(socketRecordReader); + + final Map attributes = new HashMap<>(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), mimeType); + attributes.put("tcp.sender", sender); + attributes.put("tcp.port", String.valueOf(port)); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); + flowFile = session.putAllAttributes(flowFile, attributes); + + final String senderHost = sender.startsWith("/") && sender.length() > 1 ? sender.substring(1) : sender; + final String transitUri = new StringBuilder().append("tcp").append("://").append(senderHost).append(":").append(port).toString(); + session.getProvenanceReporter().receive(flowFile, transitUri); + + session.transfer(flowFile, REL_SUCCESS); + } + + getLogger().debug("Re-queuing connection for further processing..."); + socketReaders.offer(socketRecordReader); + + } catch (Exception e) { + getLogger().error("Error processing records: " + e.getMessage(), e); + IOUtils.closeQuietly(socketRecordReader); + session.remove(flowFile); + return; + } + } + } + + private SocketChannelRecordReader pollForSocketRecordReader() { + try { + return socketReaders.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + } + + private String getRemoteAddress(final SocketChannelRecordReader socketChannelRecordReader) { + return socketChannelRecordReader.getRemoteAddress() == null ? "null" : socketChannelRecordReader.getRemoteAddress().toString(); + } + + public final int getDispatcherPort() { + return dispatcher == null ? 0 : dispatcher.getPort(); + } + +} diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutTCP.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutTCP.java index 75165d76d5..ee3e645074 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutTCP.java +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/PutTCP.java @@ -214,8 +214,10 @@ public class PutTCP extends AbstractPutEventProcessor { getLogger().error("Exception while handling a process session, transferring {} to failure.", new Object[] { flowFile }, e); } finally { if (closeSender) { + getLogger().debug("Closing sender"); sender.close(); } else { + getLogger().debug("Relinquishing sender"); relinquishSender(sender); } } diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor index aa8e931766..b8eb4a1afd 100644 --- a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/resources/META-INF/services/org.apache.nifi.processor.Processor @@ -55,6 +55,7 @@ org.apache.nifi.processors.standard.ListenHTTP org.apache.nifi.processors.standard.ListenRELP org.apache.nifi.processors.standard.ListenSyslog org.apache.nifi.processors.standard.ListenTCP +org.apache.nifi.processors.standard.ListenTCPRecord org.apache.nifi.processors.standard.ListenUDP org.apache.nifi.processors.standard.ListSFTP org.apache.nifi.processors.standard.LogAttribute diff --git a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenTCPRecord.java b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenTCPRecord.java new file mode 100644 index 0000000000..6174715655 --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenTCPRecord.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard; + +import org.apache.commons.io.IOUtils; +import org.apache.nifi.json.JsonTreeReader; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processor.ProcessSessionFactory; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.schema.access.SchemaAccessUtils; +import org.apache.nifi.security.util.SslContextFactory; +import org.apache.nifi.serialization.RecordReaderFactory; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.record.MockRecordWriter; +import org.apache.nifi.ssl.SSLContextService; +import org.apache.nifi.ssl.StandardSSLContextService; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLContext; +import java.io.Closeable; +import java.io.IOException; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.security.KeyManagementException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class TestListenTCPRecord { + + static final Logger LOGGER = LoggerFactory.getLogger(TestListenTCPRecord.class); + + static final String SCHEMA_TEXT = "{\n" + + " \"name\": \"syslogRecord\",\n" + + " \"namespace\": \"nifi\",\n" + + " \"type\": \"record\",\n" + + " \"fields\": [\n" + + " { \"name\": \"timestamp\", \"type\": \"string\" },\n" + + " { \"name\": \"logsource\", \"type\": \"string\" },\n" + + " { \"name\": \"message\", \"type\": \"string\" }\n" + + " ]\n" + + "}"; + + static final List DATA; + static { + final List data = new ArrayList<>(); + data.add("["); + data.add("{\"timestamp\" : \"123456789\", \"logsource\" : \"syslog\", \"message\" : \"This is a test 1\"},"); + data.add("{\"timestamp\" : \"123456789\", \"logsource\" : \"syslog\", \"message\" : \"This is a test 2\"},"); + data.add("{\"timestamp\" : \"123456789\", \"logsource\" : \"syslog\", \"message\" : \"This is a test 3\"}"); + data.add("]"); + DATA = Collections.unmodifiableList(data); + } + + private ListenTCPRecord proc; + private TestRunner runner; + + @Before + public void setup() throws InitializationException { + proc = new ListenTCPRecord(); + runner = TestRunners.newTestRunner(proc); + runner.setProperty(ListenTCPRecord.PORT, "0"); + + final String readerId = "record-reader"; + final RecordReaderFactory readerFactory = new JsonTreeReader(); + runner.addControllerService(readerId, readerFactory); + runner.setProperty(readerFactory, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaAccessUtils.SCHEMA_TEXT_PROPERTY.getValue()); + runner.setProperty(readerFactory, SchemaAccessUtils.SCHEMA_TEXT, SCHEMA_TEXT); + runner.enableControllerService(readerFactory); + + final String writerId = "record-writer"; + final RecordSetWriterFactory writerFactory = new MockRecordWriter("timestamp, logsource, message"); + runner.addControllerService(writerId, writerFactory); + runner.enableControllerService(writerFactory); + + runner.setProperty(ListenTCPRecord.RECORD_READER, readerId); + runner.setProperty(ListenTCPRecord.RECORD_WRITER, writerId); + } + + @Test + public void testCustomValidate() throws InitializationException { + runner.setProperty(ListenTCPRecord.PORT, "1"); + runner.assertValid(); + + configureProcessorSslContextService(); + runner.setProperty(ListenTCPRecord.CLIENT_AUTH, ""); + runner.assertNotValid(); + + runner.setProperty(ListenTCPRecord.CLIENT_AUTH, SslContextFactory.ClientAuth.REQUIRED.name()); + runner.assertValid(); + } + + @Test + public void testOneRecordPerFlowFile() throws IOException, InterruptedException { + runner.setProperty(ListenTCPRecord.RECORD_BATCH_SIZE, "1"); + + runTCP(DATA, 3, null); + + List mockFlowFiles = runner.getFlowFilesForRelationship(ListenTCPRecord.REL_SUCCESS); + for (int i=0; i < mockFlowFiles.size(); i++) { + final MockFlowFile flowFile = mockFlowFiles.get(i); + flowFile.assertAttributeEquals("record.count", "1"); + + final String content = new String(flowFile.toByteArray(), StandardCharsets.UTF_8); + Assert.assertNotNull(content); + Assert.assertTrue(content.contains("This is a test " + (i + 1))); + } + } + + @Test + public void testMultipleRecordsPerFlowFileLessThanBatchSize() throws IOException, InterruptedException { + runner.setProperty(ListenTCPRecord.RECORD_BATCH_SIZE, "5"); + + runTCP(DATA, 1, null); + + final List mockFlowFiles = runner.getFlowFilesForRelationship(ListenTCPRecord.REL_SUCCESS); + Assert.assertEquals(1, mockFlowFiles.size()); + + final MockFlowFile flowFile = mockFlowFiles.get(0); + flowFile.assertAttributeEquals("record.count", "3"); + + final String content = new String(flowFile.toByteArray(), StandardCharsets.UTF_8); + Assert.assertNotNull(content); + Assert.assertTrue(content.contains("This is a test " + 1)); + Assert.assertTrue(content.contains("This is a test " + 2)); + Assert.assertTrue(content.contains("This is a test " + 3)); + } + + @Test + public void testTLSClienAuthRequiredAndClientCertProvided() throws InitializationException, IOException, InterruptedException, UnrecoverableKeyException, + CertificateException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException { + + runner.setProperty(ListenTCPRecord.CLIENT_AUTH, SSLContextService.ClientAuth.REQUIRED.name()); + configureProcessorSslContextService(); + + // Make an SSLContext with a key and trust store to send the test messages + final SSLContext clientSslContext = SslContextFactory.createSslContext( + "src/test/resources/localhost-ks.jks", + "localtest".toCharArray(), + "jks", + "src/test/resources/localhost-ts.jks", + "localtest".toCharArray(), + "jks", + org.apache.nifi.security.util.SslContextFactory.ClientAuth.valueOf("NONE"), + "TLS"); + + runTCP(DATA, 1, clientSslContext); + + final List mockFlowFiles = runner.getFlowFilesForRelationship(ListenTCPRecord.REL_SUCCESS); + Assert.assertEquals(1, mockFlowFiles.size()); + + final String content = new String(mockFlowFiles.get(0).toByteArray(), StandardCharsets.UTF_8); + Assert.assertNotNull(content); + Assert.assertTrue(content.contains("This is a test " + 1)); + Assert.assertTrue(content.contains("This is a test " + 2)); + Assert.assertTrue(content.contains("This is a test " + 3)); + } + + @Test + public void testTLSClienAuthRequiredAndClientCertNotProvided() throws InitializationException, CertificateException, UnrecoverableKeyException, + NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, InterruptedException { + + runner.setProperty(ListenTCPRecord.CLIENT_AUTH, SSLContextService.ClientAuth.REQUIRED.name()); + runner.setProperty(ListenTCPRecord.READ_TIMEOUT, "5 seconds"); + configureProcessorSslContextService(); + + // Make an SSLContext that only has the trust store, this should not work since the processor has client auth REQUIRED + final SSLContext clientSslContext = SslContextFactory.createTrustSslContext( + "src/test/resources/localhost-ts.jks", + "localtest".toCharArray(), + "jks", + "TLS"); + + runTCP(DATA, 0, clientSslContext); + } + + @Test + public void testTLSClienAuthNoneAndClientCertNotProvided() throws InitializationException, CertificateException, UnrecoverableKeyException, + NoSuchAlgorithmException, KeyStoreException, KeyManagementException, IOException, InterruptedException { + + runner.setProperty(ListenTCPRecord.CLIENT_AUTH, SSLContextService.ClientAuth.NONE.name()); + configureProcessorSslContextService(); + + // Make an SSLContext that only has the trust store, this should work since the processor has client auth NONE + final SSLContext clientSslContext = SslContextFactory.createTrustSslContext( + "src/test/resources/localhost-ts.jks", + "localtest".toCharArray(), + "jks", + "TLS"); + + runTCP(DATA, 1, clientSslContext); + + final List mockFlowFiles = runner.getFlowFilesForRelationship(ListenTCPRecord.REL_SUCCESS); + Assert.assertEquals(1, mockFlowFiles.size()); + + final String content = new String(mockFlowFiles.get(0).toByteArray(), StandardCharsets.UTF_8); + Assert.assertNotNull(content); + Assert.assertTrue(content.contains("This is a test " + 1)); + Assert.assertTrue(content.contains("This is a test " + 2)); + Assert.assertTrue(content.contains("This is a test " + 3)); + } + + protected void runTCP(final List messages, final int expectedTransferred, final SSLContext sslContext) + throws IOException, InterruptedException { + + SocketSender sender = null; + try { + // schedule to start listening on a random port + final ProcessSessionFactory processSessionFactory = runner.getProcessSessionFactory(); + final ProcessContext context = runner.getProcessContext(); + proc.onScheduled(context); + Thread.sleep(100); + + sender = new SocketSender(proc.getDispatcherPort(), "localhost", sslContext, messages, 0); + + final Thread senderThread = new Thread(sender); + senderThread.setDaemon(true); + senderThread.start(); + + long timeout = 10000; + + // call onTrigger until we processed all the records, or a certain amount of time passes + int numTransferred = 0; + long startTime = System.currentTimeMillis(); + while (numTransferred < expectedTransferred && (System.currentTimeMillis() - startTime < timeout)) { + proc.onTrigger(context, processSessionFactory); + numTransferred = runner.getFlowFilesForRelationship(ListenTCPRecord.REL_SUCCESS).size(); + Thread.sleep(100); + } + + // should have transferred the expected events + runner.assertTransferCount(ListenTCPRecord.REL_SUCCESS, expectedTransferred); + } finally { + // unschedule to close connections + proc.onStopped(); + IOUtils.closeQuietly(sender); + } + } + + private SSLContextService configureProcessorSslContextService() throws InitializationException { + final SSLContextService sslContextService = new StandardSSLContextService(); + runner.addControllerService("ssl-context", sslContextService); + runner.setProperty(sslContextService, StandardSSLContextService.TRUSTSTORE, "src/test/resources/localhost-ts.jks"); + runner.setProperty(sslContextService, StandardSSLContextService.TRUSTSTORE_PASSWORD, "localtest"); + runner.setProperty(sslContextService, StandardSSLContextService.TRUSTSTORE_TYPE, "JKS"); + runner.setProperty(sslContextService, StandardSSLContextService.KEYSTORE, "src/test/resources/localhost-ks.jks"); + runner.setProperty(sslContextService, StandardSSLContextService.KEYSTORE_PASSWORD, "localtest"); + runner.setProperty(sslContextService, StandardSSLContextService.KEYSTORE_TYPE, "JKS"); + runner.enableControllerService(sslContextService); + + runner.setProperty(ListenTCPRecord.SSL_CONTEXT_SERVICE, "ssl-context"); + return sslContextService; + } + + private static class SocketSender implements Runnable, Closeable { + + private final int port; + private final String host; + private final SSLContext sslContext; + private final List data; + private final long delay; + + private Socket socket; + + public SocketSender(final int port, final String host, final SSLContext sslContext, final List data, final long delay) { + this.port = port; + this.host = host; + this.sslContext = sslContext; + this.data = data; + this.delay = delay; + } + + @Override + public void run() { + try { + if (sslContext != null) { + socket = sslContext.getSocketFactory().createSocket(host, port); + } else { + socket = new Socket(host, port); + } + + for (final String message : data) { + socket.getOutputStream().write(message.getBytes(StandardCharsets.UTF_8)); + if (delay > 0) { + Thread.sleep(delay); + } + } + + socket.getOutputStream().flush(); + } catch (final Exception e) { + LOGGER.error(e.getMessage(), e); + } finally { + IOUtils.closeQuietly(socket); + } + } + + public void close() { + IOUtils.closeQuietly(socket); + } + } + +}