NIFI-5952 Refactor RAW S2S from nio to socket

This commit is contained in:
Koji Kawamura 2018-12-25 11:35:52 +09:00 committed by Mark Payne
parent af94b035cb
commit e659e3b606
12 changed files with 388 additions and 557 deletions

View File

@ -34,11 +34,10 @@ import org.apache.nifi.remote.exception.PortNotRunningException;
import org.apache.nifi.remote.exception.TransmissionDisabledException;
import org.apache.nifi.remote.exception.UnknownPortException;
import org.apache.nifi.remote.exception.UnreachableClusterException;
import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannel;
import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsSession;
import org.apache.nifi.remote.protocol.socket.SocketClientProtocol;
import org.apache.nifi.security.util.CertificateUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -49,9 +48,8 @@ import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.Socket;
import java.net.URI;
import java.nio.channels.SocketChannel;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
@ -450,27 +448,23 @@ public class EndpointConnectionPool implements PeerStatusProvider {
+ " because it requires Secure Site-to-Site communications, but this instance is not configured for secure communications");
}
final SSLSocketChannel socketChannel = new SSLSocketChannel(sslContext, hostname, port, localAddress, true);
socketChannel.connect();
commsSession = new SSLSocketChannelCommunicationsSession(socketChannel);
final Socket socket = sslContext.getSocketFactory().createSocket(hostname, port);
socket.setSoTimeout(commsTimeout);
commsSession = new SocketCommunicationsSession(socket);
try {
commsSession.setUserDn(socketChannel.getDn());
final String dn = CertificateUtils.extractPeerDNFromSSLSocket(socket);
commsSession.setUserDn(dn);
} catch (final CertificateException ex) {
throw new IOException(ex);
}
} else {
final SocketChannel socketChannel = SocketChannel.open();
if (localAddress != null) {
final SocketAddress localSocketAddress = new InetSocketAddress(localAddress, 0);
socketChannel.socket().bind(localSocketAddress);
}
socketChannel.socket().connect(new InetSocketAddress(hostname, port), commsTimeout);
socketChannel.socket().setSoTimeout(commsTimeout);
final Socket socket = new Socket();
socket.connect(new InetSocketAddress(hostname, port), commsTimeout);
socket.setSoTimeout(commsTimeout);
commsSession = new SocketChannelCommunicationsSession(socketChannel);
commsSession = new SocketCommunicationsSession(socket);
}
commsSession.getOutput().getOutputStream().write(CommunicationsSession.MAGIC_BYTES);

View File

@ -16,38 +16,39 @@
*/
package org.apache.nifi.remote.io.socket;
import java.io.IOException;
import java.nio.channels.SocketChannel;
import org.apache.nifi.remote.AbstractCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsInput;
import org.apache.nifi.remote.protocol.CommunicationsOutput;
public class SocketChannelCommunicationsSession extends AbstractCommunicationsSession {
import java.io.IOException;
import java.net.Socket;
private final SocketChannel channel;
private final SocketChannelInput request;
private final SocketChannelOutput response;
public class SocketCommunicationsSession extends AbstractCommunicationsSession {
private final Socket socket;
private final SocketInput request;
private final SocketOutput response;
private int timeout = 30000;
public SocketChannelCommunicationsSession(final SocketChannel socketChannel) throws IOException {
public SocketCommunicationsSession(final Socket socket) throws IOException {
super();
request = new SocketChannelInput(socketChannel);
response = new SocketChannelOutput(socketChannel);
channel = socketChannel;
socketChannel.configureBlocking(false);
this.socket = socket;
request = new SocketInput(socket);
response = new SocketOutput(socket);
}
@Override
public boolean isClosed() {
return !channel.isConnected();
return socket.isClosed();
}
@Override
public SocketChannelInput getInput() {
public CommunicationsInput getInput() {
return request;
}
@Override
public SocketChannelOutput getOutput() {
public CommunicationsOutput getOutput() {
return response;
}
@ -74,7 +75,7 @@ public class SocketChannelCommunicationsSession extends AbstractCommunicationsSe
}
try {
channel.close();
socket.close();
} catch (final IOException ioe) {
if (suppressed != null) {
ioe.addSuppressed(suppressed);

View File

@ -19,21 +19,28 @@ package org.apache.nifi.remote.io.socket;
import org.apache.nifi.remote.io.InterruptableInputStream;
import org.apache.nifi.remote.protocol.CommunicationsInput;
import org.apache.nifi.stream.io.ByteCountingInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.channels.SocketChannel;
import java.net.Socket;
import java.net.SocketException;
public class SocketChannelInput implements CommunicationsInput {
public class SocketInput implements CommunicationsInput {
private final SocketChannelInputStream socketIn;
private static final Logger LOG = LoggerFactory.getLogger(SocketInput.class);
private final Socket socket;
private final InputStream socketIn;
private final ByteCountingInputStream countingIn;
private final InputStream bufferedIn;
private final InterruptableInputStream interruptableIn;
public SocketChannelInput(final SocketChannel socketChannel) throws IOException {
this.socketIn = new SocketChannelInputStream(socketChannel);
public SocketInput(final Socket socket) throws IOException {
this.socket = socket;
socketIn = socket.getInputStream();
countingIn = new ByteCountingInputStream(socketIn);
bufferedIn = new BufferedInputStream(countingIn);
interruptableIn = new InterruptableInputStream(bufferedIn);
@ -45,7 +52,11 @@ public class SocketChannelInput implements CommunicationsInput {
}
public void setTimeout(final int millis) {
socketIn.setTimeout(millis);
try {
socket.setSoTimeout(millis);
} catch (SocketException e) {
LOG.warn("Failed to set socket timeout.", e);
}
}
public boolean isDataAvailable() {
@ -63,11 +74,18 @@ public class SocketChannelInput implements CommunicationsInput {
public void interrupt() {
interruptableIn.interrupt();
socketIn.interrupt();
}
@Override
public void consume() throws IOException {
socketIn.consume();
if (interruptableIn == null || !isDataAvailable()) {
return;
}
final byte[] b = new byte[4096];
int bytesRead;
do {
bytesRead = interruptableIn.read(b);
} while (bytesRead > 0);
}
}

View File

@ -16,24 +16,30 @@
*/
package org.apache.nifi.remote.io.socket;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.SocketChannel;
import org.apache.nifi.remote.io.InterruptableOutputStream;
import org.apache.nifi.remote.protocol.CommunicationsOutput;
import org.apache.nifi.stream.io.ByteCountingOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class SocketChannelOutput implements CommunicationsOutput {
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketException;
private final SocketChannelOutputStream socketOutStream;
public class SocketOutput implements CommunicationsOutput {
private static final Logger LOG = LoggerFactory.getLogger(SocketOutput.class);
private final Socket socket;
private final ByteCountingOutputStream countingOut;
private final OutputStream bufferedOut;
private final InterruptableOutputStream interruptableOut;
public SocketChannelOutput(final SocketChannel socketChannel) throws IOException {
socketOutStream = new SocketChannelOutputStream(socketChannel);
countingOut = new ByteCountingOutputStream(socketOutStream);
public SocketOutput(final Socket socket) throws IOException {
this.socket = socket;
countingOut = new ByteCountingOutputStream(socket.getOutputStream());
bufferedOut = new BufferedOutputStream(countingOut);
interruptableOut = new InterruptableOutputStream(bufferedOut);
}
@ -44,7 +50,11 @@ public class SocketChannelOutput implements CommunicationsOutput {
}
public void setTimeout(final int timeout) {
socketOutStream.setTimeout(timeout);
try {
socket.setSoTimeout(timeout);
} catch (SocketException e) {
LOG.warn("Failed to set socket timeout.", e);
}
}
@Override

View File

@ -1,114 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.io.socket.ssl;
import java.io.IOException;
import org.apache.nifi.remote.AbstractCommunicationsSession;
public class SSLSocketChannelCommunicationsSession extends AbstractCommunicationsSession {
private final SSLSocketChannel channel;
private final SSLSocketChannelInput request;
private final SSLSocketChannelOutput response;
public SSLSocketChannelCommunicationsSession(final SSLSocketChannel channel) {
super();
request = new SSLSocketChannelInput(channel);
response = new SSLSocketChannelOutput(channel);
this.channel = channel;
}
@Override
public SSLSocketChannelInput getInput() {
return request;
}
@Override
public SSLSocketChannelOutput getOutput() {
return response;
}
@Override
public void setTimeout(final int millis) throws IOException {
channel.setTimeout(millis);
}
@Override
public int getTimeout() throws IOException {
return channel.getTimeout();
}
@Override
public void close() throws IOException {
IOException suppressed = null;
try {
request.consume();
} catch (final IOException ioe) {
suppressed = ioe;
}
try {
channel.close();
} catch (final IOException ioe) {
if (suppressed != null) {
ioe.addSuppressed(suppressed);
}
throw ioe;
}
if (suppressed != null) {
throw suppressed;
}
}
@Override
public boolean isClosed() {
return channel.isClosed();
}
@Override
public boolean isDataAvailable() {
try {
return request.isDataAvailable();
} catch (final Exception e) {
return false;
}
}
@Override
public long getBytesWritten() {
return response.getBytesWritten();
}
@Override
public long getBytesRead() {
return request.getBytesRead();
}
@Override
public void interrupt() {
channel.interrupt();
}
@Override
public String toString() {
return super.toString() + "[SSLSocketChannel=" + channel + "]";
}
}

View File

@ -1,55 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.io.socket.ssl;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.apache.nifi.remote.protocol.CommunicationsInput;
import org.apache.nifi.stream.io.ByteCountingInputStream;
public class SSLSocketChannelInput implements CommunicationsInput {
private final SSLSocketChannelInputStream in;
private final ByteCountingInputStream countingIn;
private final InputStream bufferedIn;
public SSLSocketChannelInput(final SSLSocketChannel socketChannel) {
in = new SSLSocketChannelInputStream(socketChannel);
countingIn = new ByteCountingInputStream(in);
this.bufferedIn = new BufferedInputStream(countingIn);
}
@Override
public InputStream getInputStream() throws IOException {
return bufferedIn;
}
public boolean isDataAvailable() throws IOException {
return bufferedIn.available() > 0;
}
@Override
public long getBytesRead() {
return countingIn.getBytesRead();
}
@Override
public void consume() throws IOException {
in.consume();
}
}

View File

@ -1,44 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.io.socket.ssl;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import org.apache.nifi.remote.protocol.CommunicationsOutput;
import org.apache.nifi.stream.io.ByteCountingOutputStream;
public class SSLSocketChannelOutput implements CommunicationsOutput {
private final OutputStream out;
private final ByteCountingOutputStream countingOut;
public SSLSocketChannelOutput(final SSLSocketChannel channel) {
countingOut = new ByteCountingOutputStream(new SSLSocketChannelOutputStream(channel));
out = new BufferedOutputStream(countingOut);
}
@Override
public OutputStream getOutputStream() throws IOException {
return out;
}
@Override
public long getBytesWritten() {
return countingOut.getBytesWritten();
}
}

View File

@ -63,7 +63,7 @@ public class SocketClientTransaction extends AbstractTransaction {
this.dataAvailable = true;
break;
case NO_MORE_DATA:
logger.debug("{} No data available from {}", peer);
logger.debug("{} No data available from {}", this, peer);
this.dataAvailable = false;
return;
default:

View File

@ -16,6 +16,31 @@
*/
package org.apache.nifi.remote.protocol.socket;
import org.apache.nifi.events.EventReporter;
import org.apache.nifi.remote.Peer;
import org.apache.nifi.remote.PeerDescription;
import org.apache.nifi.remote.Transaction;
import org.apache.nifi.remote.TransferDirection;
import org.apache.nifi.remote.codec.FlowFileCodec;
import org.apache.nifi.remote.codec.StandardFlowFileCodec;
import org.apache.nifi.remote.exception.NoContentException;
import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketInput;
import org.apache.nifi.remote.io.socket.SocketOutput;
import org.apache.nifi.remote.protocol.DataPacket;
import org.apache.nifi.remote.protocol.RequestType;
import org.apache.nifi.remote.protocol.Response;
import org.apache.nifi.remote.protocol.ResponseCode;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import static org.apache.nifi.remote.protocol.SiteToSiteTestUtils.createDataPacket;
import static org.apache.nifi.remote.protocol.SiteToSiteTestUtils.execReceiveOneFlowFile;
import static org.apache.nifi.remote.protocol.SiteToSiteTestUtils.execReceiveTwoFlowFiles;
@ -27,34 +52,10 @@ import static org.apache.nifi.remote.protocol.SiteToSiteTestUtils.execSendWithIn
import static org.apache.nifi.remote.protocol.SiteToSiteTestUtils.execSendZeroFlowFile;
import static org.apache.nifi.remote.protocol.SiteToSiteTestUtils.readContents;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import org.apache.nifi.events.EventReporter;
import org.apache.nifi.remote.Peer;
import org.apache.nifi.remote.PeerDescription;
import org.apache.nifi.remote.Transaction;
import org.apache.nifi.remote.TransferDirection;
import org.apache.nifi.remote.codec.FlowFileCodec;
import org.apache.nifi.remote.codec.StandardFlowFileCodec;
import org.apache.nifi.remote.exception.NoContentException;
import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketChannelInput;
import org.apache.nifi.remote.io.socket.SocketChannelOutput;
import org.apache.nifi.remote.protocol.DataPacket;
import org.apache.nifi.remote.protocol.RequestType;
import org.apache.nifi.remote.protocol.Response;
import org.apache.nifi.remote.protocol.ResponseCode;
import org.junit.Test;
import static org.junit.Assert.fail;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TestSocketClientTransaction {
private Logger logger = LoggerFactory.getLogger(TestSocketClientTransaction.class);
@ -63,9 +64,9 @@ public class TestSocketClientTransaction {
private SocketClientTransaction getClientTransaction(ByteArrayInputStream bis, ByteArrayOutputStream bos, TransferDirection direction) throws IOException {
PeerDescription description = null;
String peerUrl = "";
SocketChannelCommunicationsSession commsSession = mock(SocketChannelCommunicationsSession.class);
SocketChannelInput socketIn = mock(SocketChannelInput.class);
SocketChannelOutput socketOut = mock(SocketChannelOutput.class);
SocketCommunicationsSession commsSession = mock(SocketCommunicationsSession.class);
SocketInput socketIn = mock(SocketInput.class);
SocketOutput socketOut = mock(SocketOutput.class);
when(commsSession.getInput()).thenReturn(socketIn);
when(commsSession.getOutput()).thenReturn(socketOut);

View File

@ -24,17 +24,17 @@ import org.apache.nifi.remote.exception.BadRequestException;
import org.apache.nifi.remote.exception.HandshakeException;
import org.apache.nifi.remote.exception.NotAuthorizedException;
import org.apache.nifi.remote.exception.RequestExpiredException;
import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannel;
import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsSession;
import org.apache.nifi.remote.protocol.RequestType;
import org.apache.nifi.remote.protocol.ServerProtocol;
import org.apache.nifi.security.util.CertificateUtils;
import org.apache.nifi.util.NiFiProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
@ -42,12 +42,9 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@ -90,9 +87,6 @@ public class SocketRemoteSiteListener implements RemoteSiteListener {
final boolean secure = (sslContext != null);
final List<Thread> threads = new ArrayList<>();
final ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.configureBlocking(true);
serverSocketChannel.bind(new InetSocketAddress(socketPort));
stopped.set(false);
final Thread listenerThread = new Thread(new Runnable() {
@ -100,220 +94,248 @@ public class SocketRemoteSiteListener implements RemoteSiteListener {
@Override
public void run() {
while (!stopped.get()) {
LOG.trace("Accepting Connection...");
Socket acceptedSocket = null;
try {
serverSocketChannel.configureBlocking(false);
final ServerSocket serverSocket = serverSocketChannel.socket();
serverSocket.setSoTimeout(2000);
while (!stopped.get() && acceptedSocket == null) {
try {
acceptedSocket = serverSocket.accept();
} catch (final SocketTimeoutException ste) {
continue;
}
}
} catch (final IOException e) {
LOG.error("RemoteSiteListener Unable to accept connection due to {}", e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
continue;
}
LOG.trace("Got connection");
if (stopped.get()) {
break;
try (final ServerSocket serverSocket = createServerSocket()) {
serverSocket.setSoTimeout(2000);
while (!stopped.get()) {
final Socket acceptedSocket = acceptConnection(serverSocket);
if (acceptedSocket == null) {
continue;
}
if (stopped.get()) {
break;
}
final Thread thread = createWorkerThread(acceptedSocket);
thread.setName("Site-to-Site Worker Thread-" + (threadCount++));
LOG.debug("Handing connection to {}", thread);
thread.start();
threads.add(thread);
threads.removeIf(t -> !t.isAlive());
}
final Socket socket = acceptedSocket;
final SocketChannel socketChannel = socket.getChannel();
final Thread thread = new Thread(new Runnable() {
@Override
public void run() {
LOG.debug("{} Determining URL of connection", this);
final InetAddress inetAddress = socket.getInetAddress();
String clientHostName = inetAddress.getHostName();
final int slashIndex = clientHostName.indexOf("/");
if (slashIndex == 0) {
clientHostName = clientHostName.substring(1);
} else if (slashIndex > 0) {
clientHostName = clientHostName.substring(0, slashIndex);
}
final int clientPort = socket.getPort();
final String peerUri = "nifi://" + clientHostName + ":" + clientPort;
LOG.debug("{} Connection URL is {}", this, peerUri);
final CommunicationsSession commsSession;
final String dn;
try {
if (secure) {
final SSLSocketChannel sslSocketChannel = new SSLSocketChannel(sslContext, socketChannel, false);
LOG.trace("Channel is secure; connecting...");
sslSocketChannel.connect();
LOG.trace("Channel connected");
commsSession = new SSLSocketChannelCommunicationsSession(sslSocketChannel);
dn = sslSocketChannel.getDn();
commsSession.setUserDn(dn);
} else {
LOG.trace("{} Channel is not secure", this);
commsSession = new SocketChannelCommunicationsSession(socketChannel);
dn = null;
}
} catch (final Exception e) {
LOG.error("RemoteSiteListener Unable to accept connection from {} due to {}", socket, e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
try {
socketChannel.close();
} catch (IOException swallow) {
}
return;
}
LOG.info("Received connection from {}, User DN: {}", socket.getInetAddress(), dn);
final InputStream socketIn;
final OutputStream socketOut;
try {
socketIn = commsSession.getInput().getInputStream();
socketOut = commsSession.getOutput().getOutputStream();
} catch (final IOException e) {
LOG.error("Connection dropped from {} before any data was transmitted", peerUri);
try {
commsSession.close();
} catch (final IOException ioe) {
}
return;
}
final DataInputStream dis = new DataInputStream(socketIn);
final DataOutputStream dos = new DataOutputStream(socketOut);
ServerProtocol protocol = null;
Peer peer = null;
try {
// ensure that we are communicating with another NiFi
LOG.debug("Verifying magic bytes...");
verifyMagicBytes(dis, peerUri);
LOG.debug("Receiving Server Protocol Negotiation");
protocol = RemoteResourceFactory.receiveServerProtocolNegotiation(dis, dos);
protocol.setRootProcessGroup(rootGroup.get());
protocol.setNodeInformant(nodeInformant);
if (protocol instanceof PeerDescriptionModifiable) {
((PeerDescriptionModifiable)protocol).setPeerDescriptionModifier(peerDescriptionModifier);
}
final PeerDescription description = new PeerDescription(clientHostName, clientPort, sslContext != null);
peer = new Peer(description, commsSession, peerUri, "nifi://localhost:" + getPort());
LOG.debug("Handshaking....");
protocol.handshake(peer);
if (!protocol.isHandshakeSuccessful()) {
LOG.error("Handshake failed with {}; closing connection", peer);
try {
peer.close();
} catch (final IOException e) {
LOG.warn("Failed to close {} due to {}", peer, e);
}
// no need to shutdown protocol because we failed to perform handshake
return;
}
commsSession.setTimeout((int) protocol.getRequestExpiration());
LOG.info("Successfully negotiated ServerProtocol {} Version {} with {}", protocol.getResourceName(), protocol.getVersionNegotiator().getVersion(), peer);
try {
while (!protocol.isShutdown()) {
LOG.trace("Getting Protocol Request Type...");
int timeoutCount = 0;
RequestType requestType = null;
while (requestType == null) {
try {
requestType = protocol.getRequestType(peer);
} catch (final SocketTimeoutException e) {
// Give the timeout a bit longer (twice as long) to receive the Request Type,
// in order to attempt to receive more data without shutting down the socket if we don't
// have to.
LOG.debug("{} Timed out waiting to receive RequestType using {} with {}", this, protocol, peer);
timeoutCount++;
requestType = null;
if (timeoutCount >= 2) {
throw e;
}
}
}
handleRequest(protocol, peer, requestType);
}
LOG.debug("Finished communicating with {} ({})", peer, protocol);
} catch (final Exception e) {
LOG.error("Unable to communicate with remote instance {} ({}) due to {}; closing connection", peer, protocol, e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
}
} catch (final IOException e) {
LOG.error("Unable to communicate with remote instance {} due to {}; closing connection", peer, e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
} catch (final Throwable t) {
LOG.error("Handshake failed when communicating with {}; closing connection. Reason for failure: {}", peerUri, t.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", t);
}
} finally {
LOG.trace("Cleaning up");
try {
if (protocol != null && peer != null) {
protocol.shutdown(peer);
}
} catch (final Exception protocolException) {
LOG.warn("Failed to shutdown protocol due to {}", protocolException.toString());
}
try {
if (peer != null) {
peer.close();
}
} catch (final Exception peerException) {
LOG.warn("Failed to close peer due to {}; some resources may not be appropriately cleaned up", peerException.toString());
}
LOG.trace("Finished cleaning up");
}
}
});
thread.setName("Site-to-Site Worker Thread-" + (threadCount++));
LOG.debug("Handing connection to {}", thread);
thread.start();
threads.add(thread);
threads.removeIf(t -> !t.isAlive());
} catch (final IOException e) {
LOG.error("Unable to open server socket due to {}", e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
}
for(Thread thread : threads) {
if(thread != null) {
thread.interrupt();
}
}
}
private Thread createWorkerThread(Socket socket) {
return new Thread(new Runnable() {
@Override
public void run() {
LOG.debug("{} Determining URL of connection", this);
final InetAddress inetAddress = socket.getInetAddress();
String clientHostName = inetAddress.getHostName();
final int slashIndex = clientHostName.indexOf("/");
if (slashIndex == 0) {
clientHostName = clientHostName.substring(1);
} else if (slashIndex > 0) {
clientHostName = clientHostName.substring(0, slashIndex);
}
final int clientPort = socket.getPort();
final String peerUri = "nifi://" + clientHostName + ":" + clientPort;
LOG.debug("{} Connection URL is {}", this, peerUri);
final CommunicationsSession commsSession;
final String dn;
try {
if (secure) {
LOG.trace("{} Connection is secure", this);
dn = CertificateUtils.extractPeerDNFromSSLSocket(socket);
commsSession = new SocketCommunicationsSession(socket);
commsSession.setUserDn(dn);
} else {
LOG.trace("{} Connection is not secure", this);
commsSession = new SocketCommunicationsSession(socket);
dn = null;
}
} catch (final Exception e) {
LOG.error("RemoteSiteListener Unable to accept connection from {} due to {}", socket, e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
return;
}
LOG.info("Received connection from {}, User DN: {}", socket.getInetAddress(), dn);
final InputStream socketIn;
final OutputStream socketOut;
try {
socketIn = commsSession.getInput().getInputStream();
socketOut = commsSession.getOutput().getOutputStream();
} catch (final IOException e) {
LOG.error("Connection dropped from {} before any data was transmitted", peerUri);
try {
commsSession.close();
} catch (final IOException ioe) {
}
return;
}
final DataInputStream dis = new DataInputStream(socketIn);
final DataOutputStream dos = new DataOutputStream(socketOut);
ServerProtocol protocol = null;
Peer peer = null;
try {
// ensure that we are communicating with another NiFi
LOG.debug("Verifying magic bytes...");
verifyMagicBytes(dis, peerUri);
LOG.debug("Receiving Server Protocol Negotiation");
protocol = RemoteResourceFactory.receiveServerProtocolNegotiation(dis, dos);
protocol.setRootProcessGroup(rootGroup.get());
protocol.setNodeInformant(nodeInformant);
if (protocol instanceof PeerDescriptionModifiable) {
((PeerDescriptionModifiable) protocol).setPeerDescriptionModifier(peerDescriptionModifier);
}
final PeerDescription description = new PeerDescription(clientHostName, clientPort, sslContext != null);
peer = new Peer(description, commsSession, peerUri, "nifi://localhost:" + getPort());
LOG.debug("Handshaking....");
protocol.handshake(peer);
if (!protocol.isHandshakeSuccessful()) {
LOG.error("Handshake failed with {}; closing connection", peer);
try {
peer.close();
} catch (final IOException e) {
LOG.warn("Failed to close {} due to {}", peer, e);
}
// no need to shutdown protocol because we failed to perform handshake
return;
}
commsSession.setTimeout((int) protocol.getRequestExpiration());
LOG.info("Successfully negotiated ServerProtocol {} Version {} with {}",
protocol.getResourceName(), protocol.getVersionNegotiator().getVersion(), peer);
try {
while (!protocol.isShutdown()) {
LOG.trace("Getting Protocol Request Type...");
int timeoutCount = 0;
RequestType requestType = null;
while (requestType == null) {
try {
requestType = protocol.getRequestType(peer);
} catch (final SocketTimeoutException e) {
// Give the timeout a bit longer (twice as long) to receive the Request Type,
// in order to attempt to receive more data without shutting down the socket if we don't
// have to.
LOG.debug("{} Timed out waiting to receive RequestType using {} with {}", this, protocol, peer);
timeoutCount++;
requestType = null;
if (timeoutCount >= 2) {
throw e;
}
}
}
handleRequest(protocol, peer, requestType);
}
LOG.debug("Finished communicating with {} ({})", peer, protocol);
} catch (final Exception e) {
LOG.error("Unable to communicate with remote instance {} ({}) due to {}; closing connection", peer, protocol, e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
}
} catch (final IOException e) {
LOG.error("Unable to communicate with remote instance {} due to {}; closing connection", peer, e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
} catch (final Throwable t) {
LOG.error("Handshake failed when communicating with {}; closing connection. Reason for failure: {}", peerUri, t.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", t);
}
} finally {
LOG.trace("Cleaning up");
try {
if (protocol != null && peer != null) {
protocol.shutdown(peer);
}
} catch (final Exception protocolException) {
LOG.warn("Failed to shutdown protocol due to {}", protocolException.toString());
}
try {
if (peer != null) {
peer.close();
}
} catch (final Exception peerException) {
LOG.warn("Failed to close peer due to {}; some resources may not be appropriately cleaned up", peerException.toString());
}
LOG.trace("Finished cleaning up");
}
}
});
}
});
listenerThread.setName("Site-to-Site Listener");
listenerThread.start();
}
private ServerSocket createServerSocket() throws IOException {
if (sslContext != null) {
final ServerSocket serverSocket = sslContext.getServerSocketFactory().createServerSocket(socketPort);
((SSLServerSocket) serverSocket).setNeedClientAuth(true);
return serverSocket;
} else {
return new ServerSocket(socketPort);
}
}
private Socket acceptConnection(ServerSocket serverSocket) {
LOG.trace("Accepting Connection...");
Socket acceptedSocket = null;
try {
while (!stopped.get() && acceptedSocket == null) {
try {
acceptedSocket = serverSocket.accept();
} catch (final SocketTimeoutException ste) {
LOG.trace("SocketTimeoutException occurred. {}", ste.getMessage());
}
}
} catch (final IOException e) {
LOG.error("RemoteSiteListener Unable to accept connection due to {}", e.toString());
if (LOG.isDebugEnabled()) {
LOG.error("", e);
}
return acceptedSocket;
}
LOG.trace("Got connection");
return acceptedSocket;
}
private void handleRequest(final ServerProtocol protocol, final Peer peer, final RequestType requestType)
throws IOException, NotAuthorizedException, BadRequestException, RequestExpiredException {
LOG.debug("Request type from {} is {}", protocol, requestType);

View File

@ -31,7 +31,7 @@ import org.apache.nifi.provenance.ProvenanceEventType;
import org.apache.nifi.remote.client.SiteToSiteClient;
import org.apache.nifi.remote.client.SiteToSiteClientConfig;
import org.apache.nifi.remote.io.http.HttpCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsSession;
import org.apache.nifi.remote.protocol.DataPacket;
import org.apache.nifi.remote.protocol.SiteToSiteTransportProtocol;
@ -44,7 +44,6 @@ import org.junit.BeforeClass;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
@ -61,6 +60,7 @@ import org.apache.nifi.util.NiFiProperties;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@ -154,28 +154,27 @@ public class TestStandardRemoteGroupPort {
final String peerUrl = "nifi://node1.example.com:9090";
final PeerDescription peerDescription = new PeerDescription("node1.example.com", 9090, true);
try (final SocketChannel socketChannel = SocketChannel.open()) {
final CommunicationsSession commsSession = new SocketChannelCommunicationsSession(socketChannel);
commsSession.setUserDn("nifi.node1.example.com");
final Peer peer = new Peer(peerDescription, commsSession, peerUrl, REMOTE_CLUSTER_URL);
final CommunicationsSession commsSession = mock(SocketCommunicationsSession.class);
when(commsSession.createTransitUri(anyString(), anyString())).thenReturn("nifi://node1.example.com:9090/flowfile-uuid");
when(commsSession.getUserDn()).thenReturn("nifi.node1.example.com");
final Peer peer = new Peer(peerDescription, commsSession, peerUrl, REMOTE_CLUSTER_URL);
doReturn(peer).when(transaction).getCommunicant();
doReturn(peer).when(transaction).getCommunicant();
final MockFlowFile flowFile = processSession.createFlowFile("0123456789".getBytes());
sessionState.getFlowFileQueue().offer(flowFile);
final MockFlowFile flowFile = processSession.createFlowFile("0123456789".getBytes());
sessionState.getFlowFileQueue().offer(flowFile);
port.onTrigger(processContext, processSession);
port.onTrigger(processContext, processSession);
// Assert provenance.
final List<ProvenanceEventRecord> provenanceEvents = sessionState.getProvenanceEvents();
assertEquals(1, provenanceEvents.size());
final ProvenanceEventRecord provenanceEvent = provenanceEvents.get(0);
assertEquals(ProvenanceEventType.SEND, provenanceEvent.getEventType());
assertEquals(peerUrl + "/" + flowFile.getAttribute(CoreAttributes.UUID.key()), provenanceEvent.getTransitUri());
assertEquals("Remote DN=nifi.node1.example.com", provenanceEvent.getDetails());
assertEquals("remote-group-port-id", provenanceEvent.getAttribute(SiteToSiteAttributes.S2S_PORT_ID.key()));
// Assert provenance.
final List<ProvenanceEventRecord> provenanceEvents = sessionState.getProvenanceEvents();
assertEquals(1, provenanceEvents.size());
final ProvenanceEventRecord provenanceEvent = provenanceEvents.get(0);
assertEquals(ProvenanceEventType.SEND, provenanceEvent.getEventType());
assertEquals("nifi://node1.example.com:9090/flowfile-uuid", provenanceEvent.getTransitUri());
assertEquals("Remote DN=nifi.node1.example.com", provenanceEvent.getDetails());
assertEquals("remote-group-port-id", provenanceEvent.getAttribute(SiteToSiteAttributes.S2S_PORT_ID.key()));
}
}
@Test
@ -186,43 +185,42 @@ public class TestStandardRemoteGroupPort {
final String peerUrl = "nifi://node1.example.com:9090";
final PeerDescription peerDescription = new PeerDescription("node1.example.com", 9090, true);
try (final SocketChannel socketChannel = SocketChannel.open()) {
final CommunicationsSession commsSession = new SocketChannelCommunicationsSession(socketChannel);
commsSession.setUserDn("nifi.node1.example.com");
final Peer peer = new Peer(peerDescription, commsSession, peerUrl, REMOTE_CLUSTER_URL);
final CommunicationsSession commsSession = mock(SocketCommunicationsSession.class);
when(commsSession.createTransitUri(anyString(), anyString())).thenReturn("nifi://node1.example.com:9090/flowfile-uuid");
when(commsSession.getUserDn()).thenReturn("nifi.node1.example.com");
final Peer peer = new Peer(peerDescription, commsSession, peerUrl, REMOTE_CLUSTER_URL);
doReturn(peer).when(transaction).getCommunicant();
doReturn(peer).when(transaction).getCommunicant();
final String sourceFlowFileUuid = "flowfile-uuid";
final Map<String, String> attributes = new HashMap<>();
attributes.put(CoreAttributes.UUID.key(), sourceFlowFileUuid);
final byte[] dataPacketContents = "DataPacket Contents".getBytes();
final ByteArrayInputStream dataPacketInputStream = new ByteArrayInputStream(dataPacketContents);
final DataPacket dataPacket = new StandardDataPacket(attributes,
dataPacketInputStream, dataPacketContents.length);
final String sourceFlowFileUuid = "flowfile-uuid";
final Map<String, String> attributes = new HashMap<>();
attributes.put(CoreAttributes.UUID.key(), sourceFlowFileUuid);
final byte[] dataPacketContents = "DataPacket Contents".getBytes();
final ByteArrayInputStream dataPacketInputStream = new ByteArrayInputStream(dataPacketContents);
final DataPacket dataPacket = new StandardDataPacket(attributes,
dataPacketInputStream, dataPacketContents.length);
// Return null when it gets called second time.
doReturn(dataPacket).doReturn(null).when(this.transaction).receive();
// Return null when it gets called second time.
doReturn(dataPacket).doReturn(null).when(this.transaction).receive();
port.onTrigger(processContext, processSession);
port.onTrigger(processContext, processSession);
// Assert provenance.
final List<ProvenanceEventRecord> provenanceEvents = sessionState.getProvenanceEvents();
assertEquals(1, provenanceEvents.size());
final ProvenanceEventRecord provenanceEvent = provenanceEvents.get(0);
assertEquals(ProvenanceEventType.RECEIVE, provenanceEvent.getEventType());
assertEquals(peerUrl + "/" + sourceFlowFileUuid, provenanceEvent.getTransitUri());
assertEquals("Remote DN=nifi.node1.example.com", provenanceEvent.getDetails());
// Assert provenance.
final List<ProvenanceEventRecord> provenanceEvents = sessionState.getProvenanceEvents();
assertEquals(1, provenanceEvents.size());
final ProvenanceEventRecord provenanceEvent = provenanceEvents.get(0);
assertEquals(ProvenanceEventType.RECEIVE, provenanceEvent.getEventType());
assertEquals("nifi://node1.example.com:9090/flowfile-uuid", provenanceEvent.getTransitUri());
assertEquals("Remote DN=nifi.node1.example.com", provenanceEvent.getDetails());
// Assert received flow files.
processSession.assertAllFlowFilesTransferred(Relationship.ANONYMOUS);
final List<MockFlowFile> flowFiles = processSession.getFlowFilesForRelationship(Relationship.ANONYMOUS);
assertEquals(1, flowFiles.size());
final MockFlowFile flowFile = flowFiles.get(0);
flowFile.assertAttributeEquals(SiteToSiteAttributes.S2S_HOST.key(), peer.getHost());
flowFile.assertAttributeEquals(SiteToSiteAttributes.S2S_ADDRESS.key(), peer.getHost() + ":" + peer.getPort());
flowFile.assertAttributeEquals(SiteToSiteAttributes.S2S_PORT_ID.key(), "remote-group-port-id");
}
// Assert received flow files.
processSession.assertAllFlowFilesTransferred(Relationship.ANONYMOUS);
final List<MockFlowFile> flowFiles = processSession.getFlowFilesForRelationship(Relationship.ANONYMOUS);
assertEquals(1, flowFiles.size());
final MockFlowFile flowFile = flowFiles.get(0);
flowFile.assertAttributeEquals(SiteToSiteAttributes.S2S_HOST.key(), peer.getHost());
flowFile.assertAttributeEquals(SiteToSiteAttributes.S2S_ADDRESS.key(), peer.getHost() + ":" + peer.getPort());
flowFile.assertAttributeEquals(SiteToSiteAttributes.S2S_PORT_ID.key(), "remote-group-port-id");
}

View File

@ -21,9 +21,9 @@ import org.apache.nifi.remote.PeerDescription;
import org.apache.nifi.remote.StandardVersionNegotiator;
import org.apache.nifi.remote.cluster.ClusterNodeInformation;
import org.apache.nifi.remote.cluster.NodeInformation;
import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketChannelInput;
import org.apache.nifi.remote.io.socket.SocketChannelOutput;
import org.apache.nifi.remote.io.socket.SocketCommunicationsSession;
import org.apache.nifi.remote.io.socket.SocketInput;
import org.apache.nifi.remote.io.socket.SocketOutput;
import org.apache.nifi.remote.protocol.HandshakeProperties;
import org.apache.nifi.remote.protocol.HandshakeProperty;
import org.apache.nifi.remote.protocol.Response;
@ -75,9 +75,9 @@ public class TestSocketFlowFileServerProtocol {
final InputStream inputStream = new ByteArrayInputStream(inputBytes);
final SocketChannelCommunicationsSession commsSession = mock(SocketChannelCommunicationsSession.class);
final SocketChannelInput channelInput = mock(SocketChannelInput.class);
final SocketChannelOutput channelOutput = mock(SocketChannelOutput.class);
final SocketCommunicationsSession commsSession = mock(SocketCommunicationsSession.class);
final SocketInput channelInput = mock(SocketInput.class);
final SocketOutput channelOutput = mock(SocketOutput.class);
when(commsSession.getInput()).thenReturn(channelInput);
when(commsSession.getOutput()).thenReturn(channelOutput);