NIFI-9507 Corrected SSH Client handling on connect failures

- Refactored SSH Client configuration and connection to SSHClientProvider
- Implemented exception handling for configuration and connection failures
- Named SSH keep-alive thread for improved runtime tracking
- Closed SSH Client and interrupted keep-alive thread on configuration failures
- Added missing Compression Property to ListSFTP
- Corrected Hostname and Port property descriptors in ListSFTP
This commit is contained in:
exceptionfactory 2021-12-20 16:35:28 -06:00 committed by Joe Witt
parent 91f5cc3763
commit 898f9a48bc
No known key found for this signature in database
GPG Key ID: 9093BF854F811A1A
15 changed files with 917 additions and 364 deletions

View File

@ -79,13 +79,11 @@ public class ListSFTP extends ListFileTransfer {
@Override
protected List<PropertyDescriptor> getSupportedPropertyDescriptors() {
final PropertyDescriptor port = new PropertyDescriptor.Builder().fromPropertyDescriptor(UNDEFAULTED_PORT).defaultValue("22").build();
final List<PropertyDescriptor> properties = new ArrayList<>();
properties.add(FILE_TRANSFER_LISTING_STRATEGY);
properties.add(HOSTNAME);
properties.add(port);
properties.add(USERNAME);
properties.add(SFTPTransfer.HOSTNAME);
properties.add(SFTPTransfer.PORT);
properties.add(SFTPTransfer.USERNAME);
properties.add(SFTPTransfer.PASSWORD);
properties.add(SFTPTransfer.PRIVATE_KEY_PATH);
properties.add(SFTPTransfer.PRIVATE_KEY_PASSPHRASE);
@ -103,6 +101,7 @@ public class ListSFTP extends ListFileTransfer {
properties.add(SFTPTransfer.DATA_TIMEOUT);
properties.add(SFTPTransfer.USE_KEEPALIVE_ON_TIMEOUT);
properties.add(TARGET_SYSTEM_TIMESTAMP_PRECISION);
properties.add(SFTPTransfer.USE_COMPRESSION);
properties.add(SFTPTransfer.PROXY_CONFIGURATION_SERVICE);
properties.add(FTPTransfer.PROXY_TYPE);
properties.add(FTPTransfer.PROXY_HOST);

View File

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
/**
* Client Authentication Exception for authentication failures during SSH Client configuration
*/
public class ClientAuthenticationException extends ClientConfigurationException {
public ClientAuthenticationException(final String message, final Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import org.apache.nifi.processor.exception.ProcessException;
/**
* Client Configuration Exception for specific failures during SSH Client configuration
*/
public class ClientConfigurationException extends ProcessException {
public ClientConfigurationException(final String message, final Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
/**
* Client Connect Exception for connection failures during SSH Client configuration
*/
public class ClientConnectException extends ClientConfigurationException {
public ClientConnectException(final String message, final Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import javax.net.SocketFactory;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.Socket;
import java.util.Objects;
/**
* Proxy Socket Factory implementation creates Sockets using the configured Proxy
*/
public class ProxySocketFactory extends SocketFactory {
private final Proxy proxy;
public ProxySocketFactory(final Proxy proxy) {
this.proxy = Objects.requireNonNull(proxy, "Proxy required");
}
@Override
public Socket createSocket() {
return new Socket(proxy);
}
@Override
public Socket createSocket(final String host, final int port) throws IOException {
final InetSocketAddress socketAddress = new InetSocketAddress(host, port);
return createSocket(socketAddress);
}
@Override
public Socket createSocket(final String host, final int port, final InetAddress localHost, final int localPort) throws IOException {
final InetSocketAddress socketAddress = new InetSocketAddress(host, port);
final InetSocketAddress bindSocketAddress = new InetSocketAddress(localHost, localPort);
return createSocket(socketAddress, bindSocketAddress);
}
@Override
public Socket createSocket(final InetAddress host, final int port) throws IOException {
final InetSocketAddress socketAddress = new InetSocketAddress(host, port);
return createSocket(socketAddress);
}
@Override
public Socket createSocket(final InetAddress host, final int port, final InetAddress localAddress, final int localPort) throws IOException {
final InetSocketAddress socketAddress = new InetSocketAddress(host, port);
final InetSocketAddress bindSocketAddress = new InetSocketAddress(localAddress, localPort);
return createSocket(socketAddress, bindSocketAddress);
}
private Socket createSocket(final InetSocketAddress socketAddress, final InetSocketAddress bindSocketAddress) throws IOException {
final Socket socket = createSocket();
socket.bind(bindSocketAddress);
socket.connect(socketAddress);
return socket;
}
private Socket createSocket(final InetSocketAddress socketAddress) throws IOException {
final Socket socket = createSocket();
socket.connect(socketAddress);
return socket;
}
}

View File

@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import net.schmizz.sshj.SSHClient;
import org.apache.nifi.context.PropertyContext;
import java.util.Map;
/**
* SSH Client Provider for abstracting initial connection configuration of SSH Client instances
*/
public interface SSHClientProvider {
/**
* Get configured SSH Client using configured properties
*
* @param context Property Context
* @param attributes FlowFile attributes for property expression evaluation
* @return Configured SSH Client
*/
SSHClient getClient(PropertyContext context, final Map<String, String> attributes);
}

View File

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import net.schmizz.sshj.Config;
import org.apache.nifi.context.PropertyContext;
/**
* Configuration Provider for SSHJ
*/
public interface SSHConfigProvider {
/**
* Get SSH Configuration using configured properties
*
* @param identifier SSH Client identifier for runtime tracking
* @param context Property Context
* @return SSH Configuration
*/
Config getConfig(final String identifier, final PropertyContext context);
}

View File

@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.keyprovider.KeyFormat;
import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
import net.schmizz.sshj.userauth.keyprovider.KeyProviderUtil;
import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
import net.schmizz.sshj.userauth.method.AuthMethod;
import net.schmizz.sshj.userauth.method.AuthPassword;
import net.schmizz.sshj.userauth.method.AuthPublickey;
import net.schmizz.sshj.userauth.method.PasswordResponseProvider;
import net.schmizz.sshj.userauth.password.PasswordFinder;
import net.schmizz.sshj.userauth.password.PasswordUtils;
import org.apache.nifi.context.PropertyContext;
import org.apache.nifi.proxy.ProxyConfiguration;
import org.apache.nifi.util.StringUtils;
import javax.net.SocketFactory;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import static org.apache.nifi.processors.standard.util.FTPTransfer.createComponentProxyConfigSupplier;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.DATA_TIMEOUT;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.CONNECTION_TIMEOUT;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.PORT;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.USERNAME;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.PASSWORD;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.HOSTNAME;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.HOST_KEY_FILE;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.PRIVATE_KEY_PASSPHRASE;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.PRIVATE_KEY_PATH;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.STRICT_HOST_KEY_CHECKING;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.USE_COMPRESSION;
/**
* Standard implementation of SSH Client Provider
*/
public class StandardSSHClientProvider implements SSHClientProvider {
private static final SSHConfigProvider SSH_CONFIG_PROVIDER = new StandardSSHConfigProvider();
private static final List<Proxy.Type> SUPPORTED_PROXY_TYPES = Arrays.asList(Proxy.Type.HTTP, Proxy.Type.SOCKS);
private static final String ADDRESS_FORMAT = "%s:%d";
/**
* Get configured and authenticated SSH Client based on context properties
*
* @param context Property Context
* @param attributes FlowFile attributes for property expression evaluation
* @return Authenticated SSH Client
*/
@Override
public SSHClient getClient(final PropertyContext context, final Map<String, String> attributes) {
Objects.requireNonNull(context, "Property Context required");
Objects.requireNonNull(attributes, "Attributes required");
final String hostname = context.getProperty(HOSTNAME).evaluateAttributeExpressions(attributes).getValue();
final int port = context.getProperty(PORT).evaluateAttributeExpressions(attributes).asInteger();
final String address = String.format(ADDRESS_FORMAT, hostname, port);
final String username = context.getProperty(USERNAME).evaluateAttributeExpressions(attributes).getValue();
final List<AuthMethod> authMethods = getPasswordAuthMethods(context, attributes);
final Config config = SSH_CONFIG_PROVIDER.getConfig(address, context);
final SSHClient client = new SSHClient(config);
try {
setClientProperties(client, context);
} catch (final Exception e) {
closeClient(client);
throw new ClientConfigurationException(String.format("SSH Client configuration failed [%s]", address), e);
}
try {
client.connect(hostname, port);
} catch (final Exception e) {
closeClient(client);
throw new ClientConnectException(String.format("SSH Client connection failed [%s]", address), e);
}
try {
final List<AuthMethod> publicKeyAuthMethods = getPublicKeyAuthMethods(client, context, attributes);
authMethods.addAll(publicKeyAuthMethods);
client.auth(username, authMethods);
} catch (final Exception e) {
closeClient(client);
throw new ClientAuthenticationException(String.format("SSH Client authentication failed [%s]", address), e);
}
return client;
}
private void closeClient(final SSHClient client) {
try {
client.close();
} catch (final IOException e) {
throw new UncheckedIOException("SSH Client close failed", e);
} finally {
final Connection connection = client.getConnection();
final KeepAlive keepAlive = connection.getKeepAlive();
keepAlive.interrupt();
}
}
private void setClientProperties(final SSHClient client, final PropertyContext context) {
final int connectionTimeout = context.getProperty(CONNECTION_TIMEOUT).asTimePeriod(TimeUnit.MILLISECONDS).intValue();
client.setConnectTimeout(connectionTimeout);
final int dataTimeout = context.getProperty(DATA_TIMEOUT).asTimePeriod(TimeUnit.MILLISECONDS).intValue();
client.setTimeout(dataTimeout);
final boolean strictHostKeyChecking = context.getProperty(STRICT_HOST_KEY_CHECKING).asBoolean();
final String hostKeyFilePath = context.getProperty(HOST_KEY_FILE).getValue();
if (StringUtils.isNotBlank(hostKeyFilePath)) {
final File knownHosts = new File(hostKeyFilePath);
try {
client.loadKnownHosts(knownHosts);
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Loading Known Hosts [%s] Failed", hostKeyFilePath), e);
}
} else if (strictHostKeyChecking) {
try {
client.loadKnownHosts();
} catch (final IOException e) {
throw new UncheckedIOException("Loading Known Hosts Failed", e);
}
} else {
client.addHostKeyVerifier(new PromiscuousVerifier());
}
final boolean compressionEnabled = context.getProperty(USE_COMPRESSION).asBoolean();
if (compressionEnabled) {
try {
client.useCompression();
} catch (final TransportException e) {
throw new UncheckedIOException("Enabling Compression Failed", e);
}
}
final ProxyConfiguration proxyConfiguration = ProxyConfiguration.getConfiguration(context, createComponentProxyConfigSupplier(context));
final Proxy.Type proxyType = proxyConfiguration.getProxyType();
if (SUPPORTED_PROXY_TYPES.contains(proxyType)) {
final Proxy proxy = proxyConfiguration.createProxy();
final SocketFactory socketFactory = new ProxySocketFactory(proxy);
client.setSocketFactory(socketFactory);
}
}
private List<AuthMethod> getPasswordAuthMethods(final PropertyContext context, final Map<String, String> attributes) {
final List<AuthMethod> passwordAuthMethods = new ArrayList<>();
final String password = context.getProperty(PASSWORD).evaluateAttributeExpressions(attributes).getValue();
if (password != null) {
final AuthMethod authPassword = new AuthPassword(getPasswordFinder(password));
passwordAuthMethods.add(authPassword);
final PasswordResponseProvider passwordProvider = new PasswordResponseProvider(getPasswordFinder(password));
final AuthMethod authKeyboardInteractive = new AuthKeyboardInteractive(passwordProvider);
passwordAuthMethods.add(authKeyboardInteractive);
}
return passwordAuthMethods;
}
private List<AuthMethod> getPublicKeyAuthMethods(final SSHClient client, final PropertyContext context, final Map<String, String> attributes) {
final List<AuthMethod> publicKeyAuthMethods = new ArrayList<>();
final String privateKeyPath = context.getProperty(PRIVATE_KEY_PATH).evaluateAttributeExpressions(attributes).getValue();
if (privateKeyPath != null) {
final String privateKeyPassphrase = context.getProperty(PRIVATE_KEY_PASSPHRASE).evaluateAttributeExpressions(attributes).getValue();
final KeyProvider keyProvider = getKeyProvider(client, privateKeyPath, privateKeyPassphrase);
final AuthMethod authPublicKey = new AuthPublickey(keyProvider);
publicKeyAuthMethods.add(authPublicKey);
}
return publicKeyAuthMethods;
}
private KeyProvider getKeyProvider(final SSHClient client, final String privateKeyLocation, final String privateKeyPassphrase) {
final KeyFormat keyFormat = getKeyFormat(privateKeyLocation);
try {
return privateKeyPassphrase == null ? client.loadKeys(privateKeyLocation) : client.loadKeys(privateKeyLocation, privateKeyPassphrase);
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Loading Private Key File [%s] Format [%s] Failed", privateKeyLocation, keyFormat), e);
}
}
private KeyFormat getKeyFormat(final String privateKeyLocation) {
try {
final File privateKeyFile = new File(privateKeyLocation);
return KeyProviderUtil.detectKeyFileFormat(privateKeyFile);
} catch (final IOException e) {
throw new UncheckedIOException(String.format("Reading Private Key File [%s] Format Failed", privateKeyLocation), e);
}
}
private PasswordFinder getPasswordFinder(final String password) {
return PasswordUtils.createOneOff(password.toCharArray());
}
}

View File

@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.connection.ConnectionImpl;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.PropertyValue;
import org.apache.nifi.context.PropertyContext;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.KEY_EXCHANGE_ALGORITHMS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.CIPHERS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.KEY_ALGORITHMS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.MESSAGE_AUTHENTICATION_CODES_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.USE_KEEPALIVE_ON_TIMEOUT;
/**
* Standard implementation of SSH Configuration Provider
*/
public class StandardSSHConfigProvider implements SSHConfigProvider {
private static final String COMMA_SEPARATOR = ",";
private static final int KEEP_ALIVE_INTERVAL_SECONDS = 5;
private static final KeepAliveProvider DISABLED_KEEP_ALIVE_PROVIDER = new DisabledKeepAliveProvider();
/**
* Get SSH configuration based on configured properties
*
* @param identifier SSH Client identifier for runtime tracking
* @param context Property Context
* @return SSH Configuration
*/
@Override
public Config getConfig(final String identifier, final PropertyContext context) {
final DefaultConfig config = new DefaultConfig();
final KeepAliveProvider keepAliveProvider = getKeepAliveProvider(identifier, context);
config.setKeepAliveProvider(keepAliveProvider);
getOptionalProperty(context, CIPHERS_ALLOWED).ifPresent(property -> config.setCipherFactories(getFilteredValues(property, config.getCipherFactories())));
getOptionalProperty(context, KEY_ALGORITHMS_ALLOWED).ifPresent(property -> config.setKeyAlgorithms(getFilteredValues(property, config.getKeyAlgorithms())));
getOptionalProperty(context, KEY_EXCHANGE_ALGORITHMS_ALLOWED).ifPresent(property -> config.setKeyExchangeFactories(getFilteredValues(property, config.getKeyExchangeFactories())));
getOptionalProperty(context, MESSAGE_AUTHENTICATION_CODES_ALLOWED).ifPresent(property -> config.setMACFactories(getFilteredValues(property, config.getMACFactories())));
return config;
}
private Optional<String> getOptionalProperty(final PropertyContext context, final PropertyDescriptor propertyDescriptor) {
final PropertyValue propertyValue = context.getProperty(propertyDescriptor);
return propertyValue.isSet() ? Optional.of(propertyValue.evaluateAttributeExpressions().getValue()) : Optional.empty();
}
private <T> List<Factory.Named<T>> getFilteredValues(final String propertyValue, final List<Factory.Named<T>> supportedValues) {
final Set<String> configuredValues = getCommaSeparatedValues(propertyValue);
return supportedValues.stream().filter(named -> configuredValues.contains(named.getName())).collect(Collectors.toList());
}
private Set<String> getCommaSeparatedValues(final String propertyValue) {
final String[] values = propertyValue.split(COMMA_SEPARATOR);
return Arrays.stream(values).map(String::trim).collect(Collectors.toSet());
}
private KeepAliveProvider getKeepAliveProvider(final String identifier, final PropertyContext context) {
final boolean keepAliveEnabled = context.getProperty(USE_KEEPALIVE_ON_TIMEOUT).asBoolean();
return keepAliveEnabled ? new EnabledKeepAliveProvider(identifier) : DISABLED_KEEP_ALIVE_PROVIDER;
}
private static class EnabledKeepAliveProvider extends KeepAliveProvider {
private final String identifier;
private EnabledKeepAliveProvider(final String identifier) {
this.identifier = identifier;
}
@Override
public KeepAlive provide(final ConnectionImpl connection) {
final KeepAlive keepAlive = KeepAliveProvider.KEEP_ALIVE.provide(connection);
keepAlive.setName(String.format("SSH-keep-alive-%s", identifier));
keepAlive.setKeepAliveInterval(KEEP_ALIVE_INTERVAL_SECONDS);
return keepAlive;
}
}
private static class DisabledKeepAliveProvider extends KeepAliveProvider {
@Override
public KeepAlive provide(final ConnectionImpl connection) {
return new DisabledKeepAliveThread(connection);
}
}
private static class DisabledKeepAliveThread extends KeepAlive {
private DisabledKeepAliveThread(final ConnectionImpl connection) {
super(connection, "keep-alive-disabled");
}
@Override
public void run() {
}
@Override
protected void doKeepAlive() {
}
}
}

View File

@ -16,13 +16,9 @@
*/
package org.apache.nifi.processors.standard.util;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.sftp.FileAttributes;
import net.schmizz.sshj.sftp.FileMode;
import net.schmizz.sshj.sftp.RemoteFile;
@ -31,17 +27,6 @@ import net.schmizz.sshj.sftp.RemoteResourceInfo;
import net.schmizz.sshj.sftp.Response;
import net.schmizz.sshj.sftp.SFTPClient;
import net.schmizz.sshj.sftp.SFTPException;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.keyprovider.KeyFormat;
import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
import net.schmizz.sshj.userauth.keyprovider.KeyProviderUtil;
import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
import net.schmizz.sshj.userauth.method.AuthMethod;
import net.schmizz.sshj.userauth.method.AuthPassword;
import net.schmizz.sshj.userauth.method.AuthPublickey;
import net.schmizz.sshj.userauth.method.PasswordResponseProvider;
import net.schmizz.sshj.userauth.password.PasswordFinder;
import net.schmizz.sshj.userauth.password.PasswordUtils;
import net.schmizz.sshj.xfer.FilePermission;
import net.schmizz.sshj.xfer.LocalSourceFile;
import org.apache.nifi.components.PropertyDescriptor;
@ -58,39 +43,35 @@ import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.io.OutputStreamCallback;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.standard.ssh.SSHClientProvider;
import org.apache.nifi.processors.standard.ssh.StandardSSHClientProvider;
import org.apache.nifi.proxy.ProxyConfiguration;
import org.apache.nifi.proxy.ProxySpec;
import org.apache.nifi.stream.io.StreamUtils;
import javax.net.SocketFactory;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Proxy;
import java.net.Socket;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static org.apache.nifi.processors.standard.util.FTPTransfer.createComponentProxyConfigSupplier;
public class SFTPTransfer implements FileTransfer {
private static final int KEEP_ALIVE_INTERVAL_SECONDS = 5;
private static final SSHClientProvider SSH_CLIENT_PROVIDER = new StandardSSHClientProvider();
private static final Set<String> DEFAULT_KEY_ALGORITHM_NAMES;
private static final Set<String> DEFAULT_CIPHER_NAMES;
@ -561,32 +542,6 @@ public class SFTPTransfer implements FileTransfer {
}
}
private static final KeepAliveProvider NO_OP_KEEP_ALIVE = new KeepAliveProvider() {
@Override
public KeepAlive provide(final ConnectionImpl connection) {
return new KeepAlive(connection, "no-op-keep-alive") {
@Override
protected void doKeepAlive() {
// do nothing;
}
};
}
};
private static final KeepAliveProvider DEFAULT_KEEP_ALIVE_PROVIDER = new KeepAliveProvider() {
@Override
public KeepAlive provide(final ConnectionImpl connection) {
final KeepAlive keepAlive = KeepAliveProvider.KEEP_ALIVE.provide(connection);
keepAlive.setKeepAliveInterval(KEEP_ALIVE_INTERVAL_SECONDS);
return keepAlive;
}
};
protected KeepAliveProvider getKeepAliveProvider() {
final boolean useKeepAliveOnTimeout = ctx.getProperty(USE_KEEPALIVE_ON_TIMEOUT).asBoolean();
return useKeepAliveOnTimeout ? DEFAULT_KEEP_ALIVE_PROVIDER : NO_OP_KEEP_ALIVE;
}
protected SFTPClient getSFTPClient(final FlowFile flowFile) throws IOException {
// If the client is already initialized then compare the host that the client is connected to with the current
// host from the properties/flow-file, and if different then we need to close and reinitialize, if same we can reuse
@ -602,95 +557,8 @@ public class SFTPTransfer implements FileTransfer {
}
}
// Initialize a new SSHClient...
final DefaultConfig sshClientConfig = new DefaultConfig();
sshClientConfig.setKeepAliveProvider(getKeepAliveProvider());
updateConfigAlgorithms(sshClientConfig);
final SSHClient sshClient = new SSHClient(sshClientConfig);
// Create a Proxy if the config was specified, proxy will be null if type was NO_PROXY
final Proxy proxy;
final ProxyConfiguration proxyConfig = ProxyConfiguration.getConfiguration(ctx, createComponentProxyConfigSupplier(ctx));
switch (proxyConfig.getProxyType()) {
case HTTP:
case SOCKS:
proxy = proxyConfig.createProxy();
break;
default:
proxy = null;
break;
}
// If a proxy was specified, configure the client to use a SocketFactory that creates Sockets using the proxy
if (proxy != null) {
sshClient.setSocketFactory(new SocketFactory() {
@Override
public Socket createSocket() {
return new Socket(proxy);
}
@Override
public Socket createSocket(String s, int i) {
return new Socket(proxy);
}
@Override
public Socket createSocket(String s, int i, InetAddress inetAddress, int i1) {
return new Socket(proxy);
}
@Override
public Socket createSocket(InetAddress inetAddress, int i) {
return new Socket(proxy);
}
@Override
public Socket createSocket(InetAddress inetAddress, int i, InetAddress inetAddress1, int i1) {
return new Socket(proxy);
}
});
}
// If strict host key checking is false, add a HostKeyVerifier that always returns true
final boolean strictHostKeyChecking = ctx.getProperty(STRICT_HOST_KEY_CHECKING).asBoolean();
if (!strictHostKeyChecking) {
sshClient.addHostKeyVerifier(new PromiscuousVerifier());
}
// Load known hosts file if specified, otherwise load default
final String hostKeyVal = ctx.getProperty(HOST_KEY_FILE).getValue();
if (hostKeyVal != null) {
sshClient.loadKnownHosts(new File(hostKeyVal));
// Load default known_hosts file only when 'Strict Host Key Checking' property is enabled
} else if (strictHostKeyChecking) {
sshClient.loadKnownHosts();
}
// Enable compression on the client if specified in properties
final PropertyValue compressionValue = ctx.getProperty(FileTransfer.USE_COMPRESSION);
if (compressionValue != null && "true".equalsIgnoreCase(compressionValue.getValue())) {
sshClient.useCompression();
}
// Configure connection timeout
final int connectionTimeoutMillis = ctx.getProperty(FileTransfer.CONNECTION_TIMEOUT).asTimePeriod(TimeUnit.MILLISECONDS).intValue();
sshClient.setTimeout(connectionTimeoutMillis);
// Connect to the host and port
final String hostname = ctx.getProperty(HOSTNAME).evaluateAttributeExpressions(flowFile).getValue();
final int port = ctx.getProperty(PORT).evaluateAttributeExpressions(flowFile).asInteger();
sshClient.connect(hostname, port);
// Setup authentication methods...
final List<AuthMethod> authMethods = getAuthMethods(sshClient, flowFile);
// Authenticate...
final String username = ctx.getProperty(USERNAME).evaluateAttributeExpressions(flowFile).getValue();
sshClient.auth(username, authMethods);
// At this point we are connected and can create a new SFTPClient which means everything is good
this.sshClient = sshClient;
final Map<String, String> attributes = flowFile == null ? Collections.emptyMap() : flowFile.getAttributes();
this.sshClient = SSH_CLIENT_PROVIDER.getClient(ctx, attributes);
this.sftpClient = sshClient.newSFTPClient();
this.closed = false;
@ -705,50 +573,12 @@ public class SFTPTransfer implements FileTransfer {
this.homeDir = "";
// For some combination of server configuration and user home directory, getHome() can fail with "2: File not found"
// Since homeDir is only used tor SEND provenance event transit uri, this is harmless. Log and continue.
logger.debug("Failed to retrieve {} home directory due to {}", username, e.getMessage());
logger.debug("Failed to retrieve home directory due to {}", e.getMessage());
}
return sftpClient;
}
void updateConfigAlgorithms(final Config config) {
if (ctx.getProperty(CIPHERS_ALLOWED).isSet()) {
Set<String> allowedCiphers = Arrays.stream(ctx.getProperty(CIPHERS_ALLOWED).evaluateAttributeExpressions().getValue().split(","))
.map(String::trim)
.collect(Collectors.toSet());
config.setCipherFactories(config.getCipherFactories().stream()
.filter(cipherNamed -> allowedCiphers.contains(cipherNamed.getName()))
.collect(Collectors.toList()));
}
if (ctx.getProperty(KEY_ALGORITHMS_ALLOWED).isSet()) {
Set<String> allowedKeyAlgorithms = Arrays.stream(ctx.getProperty(KEY_ALGORITHMS_ALLOWED).evaluateAttributeExpressions().getValue().split(","))
.map(String::trim)
.collect(Collectors.toSet());
config.setKeyAlgorithms(config.getKeyAlgorithms().stream()
.filter(keyAlgorithmNamed -> allowedKeyAlgorithms.contains(keyAlgorithmNamed.getName()))
.collect(Collectors.toList()));
}
if (ctx.getProperty(KEY_EXCHANGE_ALGORITHMS_ALLOWED).isSet()) {
Set<String> allowedKeyExchangeAlgorithms = Arrays.stream(ctx.getProperty(KEY_EXCHANGE_ALGORITHMS_ALLOWED).evaluateAttributeExpressions().getValue().split(","))
.map(String::trim)
.collect(Collectors.toSet());
config.setKeyExchangeFactories(config.getKeyExchangeFactories().stream()
.filter(keyExchangeNamed -> allowedKeyExchangeAlgorithms.contains(keyExchangeNamed.getName()))
.collect(Collectors.toList()));
}
if (ctx.getProperty(MESSAGE_AUTHENTICATION_CODES_ALLOWED).isSet()) {
Set<String> allowedMessageAuthenticationCodes = Arrays.stream(ctx.getProperty(MESSAGE_AUTHENTICATION_CODES_ALLOWED).evaluateAttributeExpressions().getValue().split(","))
.map(String::trim)
.collect(Collectors.toSet());
config.setMACFactories(config.getMACFactories().stream()
.filter(macNamed -> allowedMessageAuthenticationCodes.contains(macNamed.getName()))
.collect(Collectors.toList()));
}
}
@Override
public String getHomeDirectory(final FlowFile flowFile) throws IOException {
getSFTPClient(flowFile);
@ -961,55 +791,4 @@ public class SFTPTransfer implements FileTransfer {
}
return number;
}
protected List<AuthMethod> getAuthMethods(final SSHClient client, final FlowFile flowFile) {
final List<AuthMethod> authMethods = new ArrayList<>();
final String privateKeyPath = ctx.getProperty(PRIVATE_KEY_PATH).evaluateAttributeExpressions(flowFile).getValue();
if (privateKeyPath != null) {
final String privateKeyPassphrase = ctx.getProperty(PRIVATE_KEY_PASSPHRASE).evaluateAttributeExpressions(flowFile).getValue();
final KeyProvider keyProvider = getKeyProvider(client, privateKeyPath, privateKeyPassphrase);
final AuthMethod authPublicKey = new AuthPublickey(keyProvider);
authMethods.add(authPublicKey);
}
final String password = ctx.getProperty(FileTransfer.PASSWORD).evaluateAttributeExpressions(flowFile).getValue();
if (password != null) {
final AuthMethod authPassword = new AuthPassword(getPasswordFinder(password));
authMethods.add(authPassword);
final PasswordResponseProvider passwordProvider = new PasswordResponseProvider(getPasswordFinder(password));
final AuthMethod authKeyboardInteractive = new AuthKeyboardInteractive(passwordProvider);
authMethods.add(authKeyboardInteractive);
}
if (logger.isDebugEnabled()) {
final List<String> methods = authMethods.stream().map(AuthMethod::getName).collect(Collectors.toList());
logger.debug("Authentication Methods Configured {}", methods);
}
return authMethods;
}
private KeyProvider getKeyProvider(final SSHClient client, final String privateKeyLocation, final String privateKeyPassphrase) {
final KeyFormat keyFormat = getKeyFormat(privateKeyLocation);
logger.debug("Loading Private Key File [{}] Format [{}]", privateKeyLocation, keyFormat);
try {
return privateKeyPassphrase == null ? client.loadKeys(privateKeyLocation) : client.loadKeys(privateKeyLocation, privateKeyPassphrase);
} catch (final IOException e) {
throw new ProcessException(String.format("Loading Private Key File [%s] Format [%s] Failed", privateKeyLocation, keyFormat), e);
}
}
private KeyFormat getKeyFormat(final String privateKeyLocation) {
try {
final File privateKeyFile = new File(privateKeyLocation);
return KeyProviderUtil.detectKeyFileFormat(privateKeyFile);
} catch (final IOException e) {
throw new ProcessException(String.format("Reading Private Key File [%s] Format Failed", privateKeyLocation), e);
}
}
private PasswordFinder getPasswordFinder(final String password) {
return PasswordUtils.createOneOff(password.toCharArray());
}
}

View File

@ -81,7 +81,7 @@ public class TestListSFTP {
sftpServer.deleteAllFilesAndDirectories();
}
@Test
@Test(timeout = 5000)
public void testListingWhileConcurrentlyWritingIntoMultipleDirectories() throws Exception {
AtomicInteger fileCounter = new AtomicInteger(1);

View File

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import org.junit.jupiter.api.Test;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.Socket;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
public class ProxySocketFactoryTest {
@Test
public void testCreateSocketNotConnected() {
final Proxy.Type proxyType = Proxy.Type.SOCKS;
final InetSocketAddress proxyAddress = new InetSocketAddress("localhost", 1080);
final Proxy proxy = new Proxy(proxyType, proxyAddress);
final ProxySocketFactory socketFactory = new ProxySocketFactory(proxy);
final Socket socket = socketFactory.createSocket();
assertNotNull(socket);
assertFalse(socket.isConnected());
}
}

View File

@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import org.apache.nifi.components.PropertyValue;
import org.apache.nifi.context.PropertyContext;
import org.apache.nifi.remote.io.socket.NetworkUtils;
import org.apache.nifi.util.MockPropertyValue;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Collections;
import static org.apache.nifi.processors.standard.util.FTPTransfer.PORT;
import static org.apache.nifi.processors.standard.util.FTPTransfer.PROXY_TYPE;
import static org.apache.nifi.processors.standard.util.FTPTransfer.PROXY_HOST;
import static org.apache.nifi.processors.standard.util.FTPTransfer.PROXY_PORT;
import static org.apache.nifi.processors.standard.util.FTPTransfer.HTTP_PROXY_USERNAME;
import static org.apache.nifi.processors.standard.util.FTPTransfer.HTTP_PROXY_PASSWORD;
import static org.apache.nifi.processors.standard.util.FTPTransfer.PROXY_TYPE_DIRECT;
import static org.apache.nifi.processors.standard.util.FileTransfer.CONNECTION_TIMEOUT;
import static org.apache.nifi.processors.standard.util.FileTransfer.DATA_TIMEOUT;
import static org.apache.nifi.processors.standard.util.FileTransfer.HOSTNAME;
import static org.apache.nifi.processors.standard.util.FileTransfer.USE_COMPRESSION;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.CIPHERS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.HOST_KEY_FILE;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.KEY_ALGORITHMS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.KEY_EXCHANGE_ALGORITHMS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.MESSAGE_AUTHENTICATION_CODES_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.PROXY_CONFIGURATION_SERVICE;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.STRICT_HOST_KEY_CHECKING;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StandardSSHClientProviderTest {
private static final PropertyValue NULL_PROPERTY_VALUE = new MockPropertyValue(null);
private static final PropertyValue BOOLEAN_TRUE_PROPERTY_VALUE = new MockPropertyValue(Boolean.TRUE.toString());
private static final PropertyValue BOOLEAN_FALSE_PROPERTY_VALUE = new MockPropertyValue(Boolean.FALSE.toString());
private static final PropertyValue TIMEOUT_PROPERTY_VALUE = new MockPropertyValue("2 s");
private static final String LOCALHOST = "localhost";
private static final PropertyValue HOSTNAME_PROPERTY = new MockPropertyValue(LOCALHOST);
@Mock
private PropertyContext context;
private StandardSSHClientProvider provider;
private int port;
@BeforeEach
public void setProvider() {
when(context.getProperty(any())).thenReturn(BOOLEAN_TRUE_PROPERTY_VALUE);
when(context.getProperty(CIPHERS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(KEY_ALGORITHMS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(KEY_EXCHANGE_ALGORITHMS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(MESSAGE_AUTHENTICATION_CODES_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(CONNECTION_TIMEOUT)).thenReturn(TIMEOUT_PROPERTY_VALUE);
when(context.getProperty(DATA_TIMEOUT)).thenReturn(TIMEOUT_PROPERTY_VALUE);
when(context.getProperty(STRICT_HOST_KEY_CHECKING)).thenReturn(BOOLEAN_FALSE_PROPERTY_VALUE);
when(context.getProperty(HOST_KEY_FILE)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(USE_COMPRESSION)).thenReturn(BOOLEAN_FALSE_PROPERTY_VALUE);
when(context.getProperty(PROXY_CONFIGURATION_SERVICE)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(PROXY_TYPE)).thenReturn(new MockPropertyValue(PROXY_TYPE_DIRECT));
when(context.getProperty(PROXY_HOST)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(PROXY_PORT)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(HTTP_PROXY_USERNAME)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(HTTP_PROXY_PASSWORD)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(HOSTNAME)).thenReturn(HOSTNAME_PROPERTY);
port = NetworkUtils.getAvailableTcpPort();
when(context.getProperty(PORT)).thenReturn(new MockPropertyValue(Integer.toString(port)));
provider = new StandardSSHClientProvider();
}
@Test
public void testGetClientConfigurationException() {
final ClientConfigurationException exception = assertThrows(ClientConfigurationException.class, () -> provider.getClient(context, Collections.emptyMap()));
assertTrue(exception.getMessage().contains(LOCALHOST));
assertTrue(exception.getMessage().contains(Integer.toString(port)));
}
@Test
public void testGetClientConnectException() {
final ClientConnectException exception = assertThrows(ClientConnectException.class, () -> provider.getClient(context, Collections.emptyMap()));
assertTrue(exception.getMessage().contains(LOCALHOST));
assertTrue(exception.getMessage().contains(Integer.toString(port)));
}
}

View File

@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.ssh;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.cipher.Cipher;
import org.apache.nifi.components.PropertyValue;
import org.apache.nifi.context.PropertyContext;
import org.apache.nifi.util.MockPropertyValue;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.CIPHERS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.KEY_ALGORITHMS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.KEY_EXCHANGE_ALGORITHMS_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.MESSAGE_AUTHENTICATION_CODES_ALLOWED;
import static org.apache.nifi.processors.standard.util.SFTPTransfer.USE_KEEPALIVE_ON_TIMEOUT;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StandardSSHConfigProviderTest {
private static final Config DEFAULT_CONFIG = new DefaultConfig();
private static final String FIRST_ALLOWED_CIPHER = "aes128-ctr";
private static final String SECOND_ALLOWED_CIPHER = "aes256-cbc";
private static final String ALLOWED_CIPHERS = String.format("%s,%s", FIRST_ALLOWED_CIPHER, SECOND_ALLOWED_CIPHER);
private static final PropertyValue NULL_PROPERTY_VALUE = new MockPropertyValue(null);
private static final int KEEP_ALIVE_ENABLED_INTERVAL = 5;
private static final int KEEP_ALIVE_DISABLED_INTERVAL = 0;
private static final String IDENTIFIER = UUID.randomUUID().toString();
@Mock
private PropertyContext context;
@Mock
private ConnectionImpl connection;
@Mock
private Transport transport;
private StandardSSHConfigProvider provider;
@BeforeEach
public void setProvider() {
when(transport.getConfig()).thenReturn(DEFAULT_CONFIG);
when(connection.getTransport()).thenReturn(transport);
provider = new StandardSSHConfigProvider();
}
@Test
public void testGetConfigDefaultValues() {
when(context.getProperty(USE_KEEPALIVE_ON_TIMEOUT)).thenReturn(new MockPropertyValue(Boolean.TRUE.toString()));
when(context.getProperty(CIPHERS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(KEY_ALGORITHMS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(KEY_EXCHANGE_ALGORITHMS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(MESSAGE_AUTHENTICATION_CODES_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
final Config config = provider.getConfig(IDENTIFIER, context);
assertNotNull(config);
final KeepAliveProvider keepAliveProvider = config.getKeepAliveProvider();
final KeepAlive keepAlive = keepAliveProvider.provide(connection);
assertEquals(KEEP_ALIVE_ENABLED_INTERVAL, keepAlive.getKeepAliveInterval());
assertNamedEquals(DEFAULT_CONFIG.getCipherFactories(), config.getCipherFactories());
assertNamedEquals(DEFAULT_CONFIG.getKeyAlgorithms(), config.getKeyAlgorithms());
assertNamedEquals(DEFAULT_CONFIG.getKeyExchangeFactories(), config.getKeyExchangeFactories());
assertNamedEquals(DEFAULT_CONFIG.getMACFactories(), config.getMACFactories());
}
@Test
public void testGetConfigCiphersAllowedKeepAliveDisabled() {
when(context.getProperty(USE_KEEPALIVE_ON_TIMEOUT)).thenReturn(new MockPropertyValue(Boolean.FALSE.toString()));
when(context.getProperty(CIPHERS_ALLOWED)).thenReturn(new MockPropertyValue(ALLOWED_CIPHERS));
when(context.getProperty(KEY_ALGORITHMS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(KEY_EXCHANGE_ALGORITHMS_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
when(context.getProperty(MESSAGE_AUTHENTICATION_CODES_ALLOWED)).thenReturn(NULL_PROPERTY_VALUE);
final Config config = provider.getConfig(IDENTIFIER, context);
assertNotNull(config);
final KeepAliveProvider keepAliveProvider = config.getKeepAliveProvider();
final KeepAlive keepAlive = keepAliveProvider.provide(connection);
assertEquals(KEEP_ALIVE_DISABLED_INTERVAL, keepAlive.getKeepAliveInterval());
final Iterator<Factory.Named<Cipher>> cipherFactories = config.getCipherFactories().iterator();
assertTrue(cipherFactories.hasNext());
final Factory.Named<Cipher> firstCipherFactory = cipherFactories.next();
assertEquals(FIRST_ALLOWED_CIPHER, firstCipherFactory.getName());
final Factory.Named<Cipher> secondCipherFactory = cipherFactories.next();
assertEquals(SECOND_ALLOWED_CIPHER, secondCipherFactory.getName());
assertFalse(cipherFactories.hasNext());
assertNamedEquals(DEFAULT_CONFIG.getKeyAlgorithms(), config.getKeyAlgorithms());
assertNamedEquals(DEFAULT_CONFIG.getKeyExchangeFactories(), config.getKeyExchangeFactories());
assertNamedEquals(DEFAULT_CONFIG.getMACFactories(), config.getMACFactories());
}
private <T> void assertNamedEquals(final List<Factory.Named<T>> expected, final List<Factory.Named<T>> actual) {
assertEquals(expected.size(), actual.size());
final Iterator<Factory.Named<T>> expectedValues = expected.iterator();
final Iterator<Factory.Named<T>> actualValues = actual.iterator();
while (expectedValues.hasNext()) {
final Factory.Named<?> expectedValue = expectedValues.next();
final Factory.Named<?> actualValue = actualValues.next();
assertEquals(expectedValue.getName(), actualValue.getName());
}
}
}

View File

@ -16,47 +16,21 @@
*/
package org.apache.nifi.processors.standard.util;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.sftp.Response;
import net.schmizz.sshj.sftp.SFTPClient;
import net.schmizz.sshj.sftp.SFTPException;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
import net.schmizz.sshj.userauth.method.AuthMethod;
import net.schmizz.sshj.userauth.method.AuthPassword;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.mock.MockComponentLogger;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.MockPropertyContext;
import org.apache.nifi.util.MockPropertyValue;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
@ -68,8 +42,6 @@ import static org.mockito.Mockito.when;
public class TestSFTPTransfer {
private static final Logger logger = LoggerFactory.getLogger(TestSFTPTransfer.class);
private SFTPTransfer createSftpTransfer(ProcessContext processContext, SFTPClient sftpClient) {
final ComponentLog componentLog = mock(ComponentLog.class);
return new SFTPTransfer(processContext, componentLog) {
@ -208,8 +180,6 @@ public class TestSFTPTransfer {
if (cnt == 0) {
// If the parent dir does not exist, no such file exception is thrown.
throw new SFTPException(Response.StatusCode.NO_SUCH_FILE, "Failure");
} else {
logger.info("Created the dir successfully for the 2nd time");
}
return true;
}).when(sftpClient).mkdir(eq("/dir1/dir2/dir3"));
@ -267,104 +237,4 @@ public class TestSFTPTransfer {
verify(sftpClient, times(0)).stat(eq("/dir1/dir2/dir3"));
verify(sftpClient).mkdir(eq("/dir1/dir2/dir3")); // dir3 was created blindly.
}
@Test
public void testRestrictSSHOptions() {
Map<PropertyDescriptor, String> propertyDescriptorValues = new HashMap<>();
DefaultConfig defaultConfig = new DefaultConfig();
String allowedMac = defaultConfig.getMACFactories().stream().map(Factory.Named::getName).collect(Collectors.toList()).get(0);
String allowedKeyAlgorithm = defaultConfig.getKeyAlgorithms().stream().map(Factory.Named::getName).collect(Collectors.toList()).get(0);
String allowedKeyExchangeAlgorithm = defaultConfig.getKeyExchangeFactories().stream().map(Factory.Named::getName).collect(Collectors.toList()).get(0);
String allowedCipher = defaultConfig.getCipherFactories().stream().map(Factory.Named::getName).collect(Collectors.toList()).get(0);
propertyDescriptorValues.put(SFTPTransfer.MESSAGE_AUTHENTICATION_CODES_ALLOWED, allowedMac);
propertyDescriptorValues.put(SFTPTransfer.CIPHERS_ALLOWED, allowedCipher);
propertyDescriptorValues.put(SFTPTransfer.KEY_ALGORITHMS_ALLOWED, allowedKeyAlgorithm);
propertyDescriptorValues.put(SFTPTransfer.KEY_EXCHANGE_ALGORITHMS_ALLOWED, allowedKeyExchangeAlgorithm);
MockPropertyContext mockPropertyContext = new MockPropertyContext(propertyDescriptorValues);
SFTPTransfer sftpTransfer = new SFTPTransfer(mockPropertyContext, new MockComponentLogger());
sftpTransfer.updateConfigAlgorithms(defaultConfig);
assertEquals(1, defaultConfig.getCipherFactories().size());
assertEquals(1, defaultConfig.getKeyAlgorithms().size());
assertEquals(1, defaultConfig.getKeyExchangeFactories().size());
assertEquals(1, defaultConfig.getMACFactories().size());
assertEquals(allowedCipher, defaultConfig.getCipherFactories().get(0).getName());
assertEquals(allowedKeyAlgorithm, defaultConfig.getKeyAlgorithms().get(0).getName());
assertEquals(allowedKeyExchangeAlgorithm, defaultConfig.getKeyExchangeFactories().get(0).getName());
assertEquals(allowedMac, defaultConfig.getMACFactories().get(0).getName());
}
@Test
public void testGetAuthMethodsPassword() {
final String password = UUID.randomUUID().toString();
final ProcessContext processContext = mock(ProcessContext.class);
when(processContext.getProperty(SFTPTransfer.PASSWORD)).thenReturn(new MockPropertyValue(password));
when(processContext.getProperty(SFTPTransfer.PRIVATE_KEY_PATH)).thenReturn(new MockPropertyValue(null));
final SFTPClient sftpClient = mock(SFTPClient.class);
final SFTPTransfer sftpTransfer = createSftpTransfer(processContext, sftpClient);
final SSHClient sshClient = new SSHClient();
final List<AuthMethod> authMethods = sftpTransfer.getAuthMethods(sshClient, null);
assertFalse("Authentication Methods not found", authMethods.isEmpty());
final Optional<AuthMethod> authPassword = authMethods.stream().filter(authMethod -> authMethod instanceof AuthPassword).findFirst();
assertTrue("Password Authentication not found", authPassword.isPresent());
final Optional<AuthMethod> authKeyboardInteractive = authMethods.stream().filter(authMethod -> authMethod instanceof AuthKeyboardInteractive).findFirst();
assertTrue("Keyboard Interactive Authentication not found", authKeyboardInteractive.isPresent());
}
@Test
public void testGetAuthMethodsPrivateKeyLoadFailed() throws IOException {
final File privateKeyFile = File.createTempFile(TestSFTPTransfer.class.getSimpleName(), ".key");
privateKeyFile.deleteOnExit();
final ProcessContext processContext = mock(ProcessContext.class);
when(processContext.getProperty(SFTPTransfer.PASSWORD)).thenReturn(new MockPropertyValue(null));
when(processContext.getProperty(SFTPTransfer.PRIVATE_KEY_PATH)).thenReturn(new MockPropertyValue(privateKeyFile.getAbsolutePath()));
when(processContext.getProperty(SFTPTransfer.PRIVATE_KEY_PASSPHRASE)).thenReturn(new MockPropertyValue(null));
final SFTPClient sftpClient = mock(SFTPClient.class);
final SFTPTransfer sftpTransfer = createSftpTransfer(processContext, sftpClient);
final SSHClient sshClient = new SSHClient();
assertThrows(ProcessException.class, () -> sftpTransfer.getAuthMethods(sshClient, null));
}
@Test
public void testGetKeepAliveProviderEnabled() {
final ProcessContext processContext = mock(ProcessContext.class);
when(processContext.getProperty(SFTPTransfer.USE_KEEPALIVE_ON_TIMEOUT)).thenReturn(new MockPropertyValue(Boolean.TRUE.toString()));
final KeepAlive keepAlive = getKeepAlive(processContext);
assertNotSame("Keep Alive Interval not configured", 0, keepAlive.getKeepAliveInterval());
}
@Test
public void testGetKeepAliveProviderDisabled() {
final ProcessContext processContext = mock(ProcessContext.class);
when(processContext.getProperty(SFTPTransfer.USE_KEEPALIVE_ON_TIMEOUT)).thenReturn(new MockPropertyValue(Boolean.FALSE.toString()));
final KeepAlive keepAlive = getKeepAlive(processContext);
assertEquals("Keep Alive Interval configured", 0, keepAlive.getKeepAliveInterval());
}
private KeepAlive getKeepAlive(final ProcessContext processContext) {
final SFTPClient sftpClient = mock(SFTPClient.class);
final SFTPTransfer sftpTransfer = createSftpTransfer(processContext, sftpClient);
final Transport transport = mock(Transport.class);
when(transport.getConfig()).thenReturn(new DefaultConfig());
final KeepAliveProvider mockKeepAliveProvider = mock(KeepAliveProvider.class);
final ConnectionImpl connection = new ConnectionImpl(transport, mockKeepAliveProvider);
final KeepAliveProvider keepAliveProvider = sftpTransfer.getKeepAliveProvider();
return keepAliveProvider.provide(connection);
}
}