allow execChannel to survive sshclient.disconnect, and not be bound by sessionTimeout

This commit is contained in:
Adrian Cole 2012-03-10 13:25:12 -08:00
parent 06ab36ae76
commit 755485537b
10 changed files with 520 additions and 184 deletions

View File

@ -55,15 +55,19 @@ public interface SshClient {
/** /**
* Execute a process and block until it is complete * Execute a process and block until it is complete
* *
* @param command command line to invoke * @param command
* command line to invoke
* @return output of the command * @return output of the command
*/ */
ExecResponse exec(String command); ExecResponse exec(String command);
/** /**
* Execute a process and allow the user to interact with it * Execute a process and allow the user to interact with it. Note that this will allow the
* session to exist indefinitely, and its connection is not closed when {@link #disconnect()} is
* called.
* *
* @param command command line to invoke * @param command
* command line to invoke
* @return reference to the running process * @return reference to the running process
* @since 1.5.0 * @since 1.5.0
*/ */

View File

@ -34,7 +34,6 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.ConnectException; import java.net.ConnectException;
import java.util.Arrays;
import javax.annotation.PreDestroy; import javax.annotation.PreDestroy;
import javax.annotation.Resource; import javax.annotation.Resource;
@ -44,6 +43,7 @@ import org.apache.commons.io.input.ProxyInputStream;
import org.apache.commons.io.output.ByteArrayOutputStream; import org.apache.commons.io.output.ByteArrayOutputStream;
import org.jclouds.compute.domain.ExecChannel; import org.jclouds.compute.domain.ExecChannel;
import org.jclouds.compute.domain.ExecResponse; import org.jclouds.compute.domain.ExecResponse;
import org.jclouds.domain.LoginCredentials;
import org.jclouds.http.handlers.BackoffLimitedRetryHandler; import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
import org.jclouds.io.Payload; import org.jclouds.io.Payload;
import org.jclouds.io.Payloads; import org.jclouds.io.Payloads;
@ -52,7 +52,6 @@ import org.jclouds.net.IPSocket;
import org.jclouds.rest.AuthorizationException; import org.jclouds.rest.AuthorizationException;
import org.jclouds.ssh.SshClient; import org.jclouds.ssh.SshClient;
import org.jclouds.ssh.SshException; import org.jclouds.ssh.SshException;
import org.jclouds.util.CredentialUtils;
import org.jclouds.util.Strings2; import org.jclouds.util.Strings2;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
@ -61,10 +60,10 @@ import com.google.common.base.Predicates;
import com.google.common.base.Splitter; import com.google.common.base.Splitter;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.io.Closeables; import com.google.common.io.Closeables;
import com.google.common.net.HostAndPort;
import com.google.inject.Inject; import com.google.inject.Inject;
import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelSftp; import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException; import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session; import com.jcraft.jsch.Session;
@ -92,10 +91,6 @@ public class JschSshClient implements SshClient {
} }
} }
private final String host;
private final int port;
private final String username;
private final String password;
private final String toString; private final String toString;
@Inject(optional = true) @Inject(optional = true)
@ -121,31 +116,31 @@ public class JschSshClient implements SshClient {
@Named("jclouds.ssh") @Named("jclouds.ssh")
protected Logger logger = Logger.NULL; protected Logger logger = Logger.NULL;
private Session session;
private final byte[] privateKey;
final byte[] emptyPassPhrase = new byte[0];
private final int timeout;
private final BackoffLimitedRetryHandler backoffLimitedRetryHandler; private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
public JschSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout, final SessionConnection sessionConnection;
String username, String password, byte[] privateKey) { final String user;
final String host;
public JschSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket,
LoginCredentials loginCredentials, int timeout) {
this.user = checkNotNull(loginCredentials, "loginCredentials").getUser();
this.host = checkNotNull(socket, "socket").getAddress(); this.host = checkNotNull(socket, "socket").getAddress();
checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort()); checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
checkArgument(password != null || privateKey != null, "you must specify a password or a key"); checkArgument(loginCredentials.getPassword() != null || loginCredentials.getPrivateKey() != null,
this.port = socket.getPort(); "you must specify a password or a key");
this.username = checkNotNull(username, "username");
this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler"); this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
this.timeout = timeout; if (loginCredentials.getPrivateKey() == null) {
this.password = password; this.toString = String.format("%s:pw[%s]@%s:%d", loginCredentials.getUser(), hex(md5(loginCredentials
this.privateKey = privateKey; .getPassword().getBytes())), host, socket.getPort());
if (privateKey == null) {
this.toString = String.format("%s:pw[%s]@%s:%d", username, hex(md5(password.getBytes())), host, port);
} else { } else {
String fingerPrint = fingerprintPrivateKey(new String(privateKey)); String fingerPrint = fingerprintPrivateKey(loginCredentials.getPrivateKey());
String sha1 = sha1PrivateKey(new String(privateKey)); String sha1 = sha1PrivateKey(loginCredentials.getPrivateKey());
this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", username, fingerPrint, sha1, host, this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", loginCredentials.getUser(),
port); fingerPrint, sha1, host, socket.getPort());
} }
sessionConnection = SessionConnection.builder().hostAndPort(HostAndPort.fromParts(host, socket.getPort())).loginCredentials(
loginCredentials).connectTimeout(timeout).sessionTimeout(timeout).build();
} }
@Override @Override
@ -154,7 +149,8 @@ public class JschSshClient implements SshClient {
} }
private void checkConnected() { private void checkConnected() {
checkState(session != null && session.isConnected(), String.format("(%s) Session not connected!", toString())); checkState(sessionConnection.getSession() != null && sessionConnection.getSession().isConnected(), String.format(
"(%s) Session not connected!", toString()));
} }
public static interface Connection<T> { public static interface Connection<T> {
@ -163,45 +159,6 @@ public class JschSshClient implements SshClient {
T create() throws Exception; T create() throws Exception;
} }
Connection<Session> sessionConnection = new Connection<Session>() {
@Override
public void clear() {
if (session != null && session.isConnected()) {
session.disconnect();
session = null;
}
}
@Override
public Session create() throws Exception {
JSch jsch = new JSch();
session = jsch.getSession(username, host, port);
if (timeout != 0)
session.setTimeout(timeout);
if (password != null) {
session.setPassword(password);
} else {
// jsch wipes out your private key
if (CredentialUtils.isPrivateKeyEncrypted(privateKey)) {
throw new IllegalArgumentException(
"JschSshClientModule does not support private keys that require a passphrase");
}
jsch.addIdentity(username, Arrays.copyOf(privateKey, privateKey.length), null, emptyPassPhrase);
}
java.util.Properties config = new java.util.Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(timeout);
return session;
}
@Override
public String toString() {
return String.format("Session(timeout=%d)", timeout);
}
};
protected <T, C extends Connection<T>> T acquire(C connection) { protected <T, C extends Connection<T>> T acquire(C connection) {
connection.clear(); connection.clear();
String errorMessage = String.format("(%s) error acquiring %s", toString(), connection); String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
@ -245,7 +202,7 @@ public class JschSshClient implements SshClient {
public ChannelSftp create() throws JSchException { public ChannelSftp create() throws JSchException {
checkConnected(); checkConnected();
String channel = "sftp"; String channel = "sftp";
sftp = (ChannelSftp) session.openChannel(channel); sftp = (ChannelSftp) sessionConnection.getSession().openChannel(channel);
sftp.connect(); sftp.connect();
return sftp; return sftp;
} }
@ -394,7 +351,7 @@ public class JschSshClient implements SshClient {
public ChannelExec create() throws Exception { public ChannelExec create() throws Exception {
checkConnected(); checkConnected();
String channel = "exec"; String channel = "exec";
executor = (ChannelExec) session.openChannel(channel); executor = (ChannelExec) sessionConnection.getSession().openChannel(channel);
executor.setPty(true); executor.setPty(true);
executor.setCommand(command); executor.setCommand(command);
ByteArrayOutputStream error = new ByteArrayOutputStream(); ByteArrayOutputStream error = new ByteArrayOutputStream();
@ -468,13 +425,13 @@ public class JschSshClient implements SshClient {
@Override @Override
public String getUsername() { public String getUsername() {
return this.username; return this.user;
} }
class ExecChannelConnection implements Connection<ExecChannel> { class ExecChannelConnection implements Connection<ExecChannel> {
private final String command; private final String command;
private ChannelExec executor = null; private ChannelExec executor = null;
private Session sessionConnection;
ExecChannelConnection(String command) { ExecChannelConnection(String command) {
this.command = checkNotNull(command, "command"); this.command = checkNotNull(command, "command");
@ -484,13 +441,16 @@ public class JschSshClient implements SshClient {
public void clear() { public void clear() {
if (executor != null) if (executor != null)
executor.disconnect(); executor.disconnect();
if (sessionConnection != null)
sessionConnection.disconnect();
} }
@Override @Override
public ExecChannel create() throws Exception { public ExecChannel create() throws Exception {
checkConnected(); this.sessionConnection = acquire(SessionConnection.builder().fromSessionConnection(
JschSshClient.this.sessionConnection).sessionTimeout(0).build());
String channel = "exec"; String channel = "exec";
executor = (ChannelExec) session.openChannel(channel); executor = (ChannelExec) sessionConnection.openChannel(channel);
executor.setCommand(command); executor.setCommand(command);
ByteArrayOutputStream error = new ByteArrayOutputStream(); ByteArrayOutputStream error = new ByteArrayOutputStream();
executor.setErrStream(error); executor.setErrStream(error);
@ -520,7 +480,6 @@ public class JschSshClient implements SshClient {
} }
}; };
@Override @Override
public ExecChannel execChannel(String command) { public ExecChannel execChannel(String command) {
return acquire(new ExecChannelConnection(command)); return acquire(new ExecChannelConnection(command));

View File

@ -0,0 +1,201 @@
/**
* Licensed to jclouds, Inc. (jclouds) under one or more
* contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. jclouds licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.jclouds.ssh.jsch;
import static com.google.common.base.Objects.equal;
import java.util.Arrays;
import org.jclouds.domain.LoginCredentials;
import org.jclouds.ssh.jsch.JschSshClient.Connection;
import org.jclouds.util.CredentialUtils;
import com.google.common.base.Objects;
import com.google.common.net.HostAndPort;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
public class SessionConnection implements Connection<Session> {
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected HostAndPort hostAndPort;
protected LoginCredentials loginCredentials;
protected int connectTimeout;
protected int sessionTimeout;
/**
* @see SessionConnection#getHostAndPort()
*/
public Builder hostAndPort(HostAndPort hostAndPort) {
this.hostAndPort = hostAndPort;
return this;
}
/**
* @see SessionConnection#getLoginCredentials()
*/
public Builder loginCredentials(LoginCredentials loginCredentials) {
this.loginCredentials = loginCredentials;
return this;
}
/**
* @see SessionConnection#getConnectTimeout()
*/
public Builder connectTimeout(int connectTimeout) {
this.connectTimeout = connectTimeout;
return this;
}
/**
* @see SessionConnection#getConnectTimeout()
*/
public Builder sessionTimeout(int sessionTimeout) {
this.sessionTimeout = sessionTimeout;
return this;
}
public SessionConnection build() {
return new SessionConnection(hostAndPort, loginCredentials, connectTimeout, sessionTimeout);
}
protected Builder fromSessionConnection(SessionConnection in) {
return hostAndPort(in.getHostAndPort()).connectTimeout(in.getConnectTimeout()).loginCredentials(
in.getLoginCredentials());
}
}
private SessionConnection(HostAndPort hostAndPort, LoginCredentials loginCredentials, int connectTimeout,
int sessionTimeout) {
this.hostAndPort = hostAndPort;
this.loginCredentials = loginCredentials;
this.connectTimeout = connectTimeout;
this.sessionTimeout = sessionTimeout;
}
private static final byte[] emptyPassPhrase = new byte[0];
private final HostAndPort hostAndPort;
private final LoginCredentials loginCredentials;
private final int connectTimeout;
private final int sessionTimeout;
private transient Session session;
@Override
public void clear() {
if (session != null && session.isConnected()) {
session.disconnect();
session = null;
}
}
@Override
public Session create() throws Exception {
JSch jsch = new JSch();
session = jsch
.getSession(loginCredentials.getUser(), hostAndPort.getHostText(), hostAndPort.getPortOrDefault(22));
if (sessionTimeout != 0)
session.setTimeout(sessionTimeout);
if (loginCredentials.getPrivateKey() == null) {
session.setPassword(loginCredentials.getPassword());
} else {
byte[] privateKey = loginCredentials.getPrivateKey().getBytes();
if (CredentialUtils.isPrivateKeyEncrypted(privateKey)) {
throw new IllegalArgumentException(
"JschSshClientModule does not support private keys that require a passphrase");
}
jsch.addIdentity(loginCredentials.getUser(), Arrays.copyOf(privateKey, privateKey.length), null,
emptyPassPhrase);
}
java.util.Properties config = new java.util.Properties();
config.put("StrictHostKeyChecking", "no");
session.setConfig(config);
session.connect(connectTimeout);
return session;
}
/**
* @return host and port, where port if not present defaults to {@code 22}
*/
public HostAndPort getHostAndPort() {
return hostAndPort;
}
/**
*
* @return login used in this session
*/
public LoginCredentials getLoginCredentials() {
return loginCredentials;
}
/**
*
* @return how long to wait for the initial connection to be made
*/
public int getConnectTimeout() {
return connectTimeout;
}
/**
*
* @return how long to keep the session open, or {@code 0} for indefinitely
*/
public int getSessionTimeout() {
return sessionTimeout;
}
/**
*
* @return the current session or {@code null} if not connected
*/
public Session getSession() {
return session;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
SessionConnection that = SessionConnection.class.cast(o);
return equal(this.hostAndPort, that.hostAndPort) && equal(this.loginCredentials, that.loginCredentials)
&& equal(this.session, that.session);
}
@Override
public int hashCode() {
return Objects.hashCode(hostAndPort, loginCredentials, session);
}
@Override
public String toString() {
return Objects.toStringHelper("").add("hostAndPort", hostAndPort).add("loginUser", loginCredentials.getUser())
.add("session", session != null ? session.hashCode() : null).add("connectTimeout", connectTimeout).add(
"sessionTimeout", sessionTimeout).toString();
}
}

View File

@ -65,9 +65,7 @@ public class JschSshClientModule extends AbstractModule {
@Override @Override
public SshClient create(IPSocket socket, LoginCredentials credentials) { public SshClient create(IPSocket socket, LoginCredentials credentials) {
SshClient client = new JschSshClient(backoffLimitedRetryHandler, socket, timeout, credentials.getUser(), SshClient client = new JschSshClient(backoffLimitedRetryHandler, socket, credentials, timeout);
(credentials.getPrivateKey() == null) ? credentials.getPassword() : null,
credentials.getPrivateKey() != null ? credentials.getPrivateKey().getBytes() : null);
injector.injectMembers(client);// add logger injector.injectMembers(client);// add logger
return client; return client;
} }

View File

@ -170,8 +170,10 @@ public class JschSshClientLiveTest {
: sshHost); : sshHost);
} }
public void testExecChannelTakesStdinAndNoEchoOfCharsInOuput() throws IOException { public void testExecChannelTakesStdinAndNoEchoOfCharsInOuputAndOutlivesClient() throws IOException {
ExecChannel response = setupClient().execChannel("cat <<EOF"); SshClient client = setupClient();
ExecChannel response = client.execChannel("cat <<EOF");
client.disconnect();
assertEquals(response.getExitStatus().get(), null); assertEquals(response.getExitStatus().get(), null);
try { try {
PrintStream printStream = new PrintStream(response.getInput()); PrintStream printStream = new PrintStream(response.getInput());

View File

@ -0,0 +1,209 @@
/**
* Licensed to jclouds, Inc. (jclouds) under one or more
* contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. jclouds licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.jclouds.sshj;
import static com.google.common.base.Objects.equal;
import java.io.IOException;
import javax.annotation.Resource;
import javax.inject.Named;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
import org.jclouds.domain.LoginCredentials;
import org.jclouds.logging.Logger;
import org.jclouds.sshj.SshjSshClient.Connection;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.net.HostAndPort;
public class SSHClientConnection implements Connection<SSHClient> {
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected HostAndPort hostAndPort;
protected LoginCredentials loginCredentials;
protected int connectTimeout;
protected int sessionTimeout;
/**
* @see SSHClientConnection#getHostAndPort()
*/
public Builder hostAndPort(HostAndPort hostAndPort) {
this.hostAndPort = hostAndPort;
return this;
}
/**
* @see SSHClientConnection#getLoginCredentials()
*/
public Builder loginCredentials(LoginCredentials loginCredentials) {
this.loginCredentials = loginCredentials;
return this;
}
/**
* @see SSHClientConnection#getConnectTimeout()
*/
public Builder connectTimeout(int connectTimeout) {
this.connectTimeout = connectTimeout;
return this;
}
/**
* @see SSHClientConnection#getConnectTimeout()
*/
public Builder sessionTimeout(int sessionTimeout) {
this.sessionTimeout = sessionTimeout;
return this;
}
public SSHClientConnection build() {
return new SSHClientConnection(hostAndPort, loginCredentials, connectTimeout, sessionTimeout);
}
protected Builder fromSSHClientConnection(SSHClientConnection in) {
return hostAndPort(in.getHostAndPort()).connectTimeout(in.getConnectTimeout()).loginCredentials(
in.getLoginCredentials());
}
}
private SSHClientConnection(HostAndPort hostAndPort, LoginCredentials loginCredentials, int connectTimeout,
int sessionTimeout) {
this.hostAndPort = hostAndPort;
this.loginCredentials = loginCredentials;
this.connectTimeout = connectTimeout;
this.sessionTimeout = sessionTimeout;
}
@Resource
@Named("jclouds.ssh")
protected Logger logger = Logger.NULL;
private final HostAndPort hostAndPort;
private final LoginCredentials loginCredentials;
private final int connectTimeout;
private final int sessionTimeout;
@VisibleForTesting
transient SSHClient ssh;
@Override
public void clear() {
if (ssh != null && ssh.isConnected()) {
try {
ssh.disconnect();
} catch (IOException e) {
logger.debug("<< exception disconnecting from %s: %s", e, e.getMessage());
}
ssh = null;
}
}
@Override
public SSHClient create() throws Exception {
ssh = new net.schmizz.sshj.SSHClient();
ssh.addHostKeyVerifier(new PromiscuousVerifier());
if (connectTimeout != 0) {
ssh.setConnectTimeout(connectTimeout);
}
if (sessionTimeout != 0) {
ssh.setTimeout(sessionTimeout);
}
ssh.connect(hostAndPort.getHostText(), hostAndPort.getPortOrDefault(22));
if (loginCredentials.getPassword() != null) {
ssh.authPassword(loginCredentials.getUser(), loginCredentials.getPassword());
} else {
OpenSSHKeyFile key = new OpenSSHKeyFile();
key.init(loginCredentials.getPrivateKey(), null);
ssh.authPublickey(loginCredentials.getUser(), key);
}
return ssh;
}
/**
* @return host and port, where port if not present defaults to {@code 22}
*/
public HostAndPort getHostAndPort() {
return hostAndPort;
}
/**
*
* @return login used in this ssh
*/
public LoginCredentials getLoginCredentials() {
return loginCredentials;
}
/**
*
* @return how long to wait for the initial connection to be made
*/
public int getConnectTimeout() {
return connectTimeout;
}
/**
*
* @return how long to keep the ssh open, or {@code 0} for indefinitely
*/
public int getSessionTimeout() {
return sessionTimeout;
}
/**
*
* @return the current ssh or {@code null} if not connected
*/
public SSHClient getSSHClient() {
return ssh;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
SSHClientConnection that = SSHClientConnection.class.cast(o);
return equal(this.hostAndPort, that.hostAndPort) && equal(this.loginCredentials, that.loginCredentials)
&& equal(this.ssh, that.ssh);
}
@Override
public int hashCode() {
return Objects.hashCode(hostAndPort, loginCredentials, ssh);
}
@Override
public String toString() {
return Objects.toStringHelper("").add("hostAndPort", hostAndPort).add("loginUser", loginCredentials.getUser())
.add("ssh", ssh != null ? ssh.hashCode() : null).add("connectTimeout", connectTimeout).add(
"sessionTimeout", sessionTimeout).toString();
}
}

View File

@ -52,14 +52,13 @@ import net.schmizz.sshj.connection.channel.direct.Session.Command;
import net.schmizz.sshj.sftp.SFTPClient; import net.schmizz.sshj.sftp.SFTPClient;
import net.schmizz.sshj.sftp.SFTPException; import net.schmizz.sshj.sftp.SFTPException;
import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.UserAuthException; import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
import net.schmizz.sshj.xfer.InMemorySourceFile; import net.schmizz.sshj.xfer.InMemorySourceFile;
import org.apache.commons.io.input.ProxyInputStream; import org.apache.commons.io.input.ProxyInputStream;
import org.jclouds.compute.domain.ExecChannel; import org.jclouds.compute.domain.ExecChannel;
import org.jclouds.compute.domain.ExecResponse; import org.jclouds.compute.domain.ExecResponse;
import org.jclouds.domain.LoginCredentials;
import org.jclouds.http.handlers.BackoffLimitedRetryHandler; import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
import org.jclouds.io.Payload; import org.jclouds.io.Payload;
import org.jclouds.io.Payloads; import org.jclouds.io.Payloads;
@ -77,6 +76,7 @@ import com.google.common.base.Splitter;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.io.Closeables; import com.google.common.io.Closeables;
import com.google.common.net.HostAndPort;
import com.google.inject.Inject; import com.google.inject.Inject;
/** /**
@ -104,10 +104,6 @@ public class SshjSshClient implements SshClient {
} }
} }
private final String host;
private final int port;
private final String username;
private final String password;
private final String toString; private final String toString;
@Inject(optional = true) @Inject(optional = true)
@ -129,41 +125,42 @@ public class SshjSshClient implements SshClient {
@Named("jclouds.ssh.retry-predicate") @Named("jclouds.ssh.retry-predicate")
// NOTE cannot retry io exceptions, as SSHException is a part of the chain // NOTE cannot retry io exceptions, as SSHException is a part of the chain
private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class), private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class),
instanceOf(ConnectException.class), instanceOf(SocketTimeoutException.class), instanceOf(ConnectException.class), instanceOf(SocketTimeoutException.class),
instanceOf(TransportException.class), instanceOf(TransportException.class),
// safe to retry sftp exceptions as they are idempotent // safe to retry sftp exceptions as they are idempotent
instanceOf(SFTPException.class)); instanceOf(SFTPException.class));
@Resource @Resource
@Named("jclouds.ssh") @Named("jclouds.ssh")
protected Logger logger = Logger.NULL; protected Logger logger = Logger.NULL;
@VisibleForTesting @VisibleForTesting
SSHClient ssh; SSHClientConnection sshClientConnection;
private final byte[] privateKey;
final byte[] emptyPassPhrase = new byte[0]; final String user;
private final int timeoutMillis; final String host;
private final BackoffLimitedRetryHandler backoffLimitedRetryHandler; private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout, public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket,
String username, String password, byte[] privateKey) { LoginCredentials loginCredentials, int timeout) {
this.user = checkNotNull(loginCredentials, "loginCredentials").getUser();
this.host = checkNotNull(socket, "socket").getAddress(); this.host = checkNotNull(socket, "socket").getAddress();
checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort()); checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
checkArgument(password != null || privateKey != null, "you must specify a password or a key"); checkArgument(loginCredentials.getPassword() != null || loginCredentials.getPrivateKey() != null,
this.port = socket.getPort(); "you must specify a password or a key");
this.username = checkNotNull(username, "username");
this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler"); this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
this.timeoutMillis = timeout; if (loginCredentials.getPrivateKey() == null) {
this.password = password; this.toString = String.format("%s:pw[%s]@%s:%d", loginCredentials.getUser(), hex(md5(loginCredentials
this.privateKey = privateKey; .getPassword().getBytes())), host, socket.getPort());
if (privateKey == null) {
this.toString = String.format("%s:pw[%s]@%s:%d", username, hex(md5(password.getBytes())), host, port);
} else { } else {
String fingerPrint = fingerprintPrivateKey(new String(privateKey)); String fingerPrint = fingerprintPrivateKey(loginCredentials.getPrivateKey());
String sha1 = sha1PrivateKey(new String(privateKey)); String sha1 = sha1PrivateKey(loginCredentials.getPrivateKey());
this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", username, fingerPrint, sha1, host, this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", loginCredentials.getUser(),
port); fingerPrint, sha1, host, socket.getPort());
} }
sshClientConnection = SSHClientConnection.builder().hostAndPort(HostAndPort.fromParts(host, socket.getPort()))
.loginCredentials(loginCredentials).connectTimeout(timeout).sessionTimeout(timeout).build();
} }
@Override @Override
@ -172,7 +169,8 @@ public class SshjSshClient implements SshClient {
} }
private void checkConnected() { private void checkConnected() {
checkState(ssh != null && ssh.isConnected(), String.format("(%s) ssh not connected!", toString())); checkState(sshClientConnection.ssh != null && sshClientConnection.ssh.isConnected(), String
.format("(%s) ssh not connected!", toString()));
} }
public static interface Connection<T> { public static interface Connection<T> {
@ -181,45 +179,6 @@ public class SshjSshClient implements SshClient {
T create() throws Exception; T create() throws Exception;
} }
Connection<net.schmizz.sshj.SSHClient> sshConnection = new Connection<net.schmizz.sshj.SSHClient>() {
@Override
public void clear() {
if (ssh != null && ssh.isConnected()) {
try {
ssh.disconnect();
} catch (IOException e) {
logger.warn(e, "<< exception disconnecting from %s: %s", e, e.getMessage());
}
ssh = null;
}
}
@Override
public net.schmizz.sshj.SSHClient create() throws Exception {
net.schmizz.sshj.SSHClient ssh = new net.schmizz.sshj.SSHClient();
ssh.addHostKeyVerifier(new PromiscuousVerifier());
if (timeoutMillis != 0) {
ssh.setTimeout(timeoutMillis);
ssh.setConnectTimeout(timeoutMillis);
}
ssh.connect(host, port);
if (password != null) {
ssh.authPassword(username, password);
} else {
OpenSSHKeyFile key = new OpenSSHKeyFile();
key.init(new String(privateKey), null);
ssh.authPublickey(username, key);
}
return ssh;
}
@Override
public String toString() {
return String.format("SSHClient(timeout=%d)", timeoutMillis);
}
};
private void backoffForAttempt(int retryAttempt, String message) { private void backoffForAttempt(int retryAttempt, String message) {
backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message); backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
} }
@ -240,16 +199,17 @@ public class SshjSshClient implements SshClient {
logger.warn(from, "<< (%s) error closing connection", toString()); logger.warn(from, "<< (%s) error closing connection", toString());
} }
if (i + 1 == sshRetries) { if (i + 1 == sshRetries) {
throw propagate(from, errorMessage+" (out of retries - max "+sshRetries+")"); throw propagate(from, errorMessage + " (out of retries - max " + sshRetries + ")");
} else if (shouldRetry(from) || } else if (shouldRetry(from)
(Throwables2.getFirstThrowableOfType(from, IllegalStateException.class) != null)) { || (Throwables2.getFirstThrowableOfType(from, IllegalStateException.class) != null)) {
logger.info("<< " + errorMessage + " (attempt " + (i + 1) + " of " + sshRetries + "): " + from.getMessage()); logger.info("<< " + errorMessage + " (attempt " + (i + 1) + " of " + sshRetries + "): "
+ from.getMessage());
backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage()); backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
if (connection != sshConnection) if (connection != sshClientConnection)
connect(); connect();
continue; continue;
} else { } else {
throw propagate(from, errorMessage+" (not retryable)"); throw propagate(from, errorMessage + " (not retryable)");
} }
} }
} }
@ -259,7 +219,7 @@ public class SshjSshClient implements SshClient {
public void connect() { public void connect() {
try { try {
ssh = acquire(sshConnection); acquire(sshClientConnection);
} catch (Exception e) { } catch (Exception e) {
Throwables.propagate(e); Throwables.propagate(e);
} }
@ -282,7 +242,7 @@ public class SshjSshClient implements SshClient {
@Override @Override
public SFTPClient create() throws IOException { public SFTPClient create() throws IOException {
checkConnected(); checkConnected();
sftp = ssh.newSFTPClient(); sftp = sshClientConnection.ssh.newSFTPClient();
return sftp; return sftp;
} }
@ -310,7 +270,7 @@ public class SshjSshClient implements SshClient {
public Payload create() throws Exception { public Payload create() throws Exception {
sftp = acquire(sftpConnection); sftp = acquire(sftpConnection);
return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path) return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path)
.getInputStream(), sftp)); .getInputStream(), sftp));
} }
@Override @Override
@ -385,7 +345,7 @@ public class SshjSshClient implements SshClient {
@VisibleForTesting @VisibleForTesting
boolean shouldRetry(Exception from) { boolean shouldRetry(Exception from) {
Predicate<Throwable> predicate = retryAuth ? Predicates.<Throwable> or(retryPredicate, Predicate<Throwable> predicate = retryAuth ? Predicates.<Throwable> or(retryPredicate,
instanceOf(AuthorizationException.class), instanceOf(UserAuthException.class)) : retryPredicate; instanceOf(AuthorizationException.class), instanceOf(UserAuthException.class)) : retryPredicate;
if (any(getCausalChain(from), predicate)) if (any(getCausalChain(from), predicate))
return true; return true;
if (!retryableMessages.equals("")) if (!retryableMessages.equals(""))
@ -404,7 +364,7 @@ public class SshjSshClient implements SshClient {
@Override @Override
public boolean apply(Throwable arg0) { public boolean apply(Throwable arg0) {
return (arg0.toString().indexOf(input) != -1) return (arg0.toString().indexOf(input) != -1)
|| (arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1); || (arg0.getMessage() != null && arg0.getMessage().indexOf(input) != -1);
} }
}); });
@ -420,7 +380,7 @@ public class SshjSshClient implements SshClient {
if (e instanceof UserAuthException) if (e instanceof UserAuthException)
throw new AuthorizationException("(" + toString() + ") " + message, e); throw new AuthorizationException("(" + toString() + ") " + message, e);
throw e instanceof SshException ? SshException.class.cast(e) : new SshException( throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
"(" + toString() + ") " + message, e); "(" + toString() + ") " + message, e);
} }
@Override @Override
@ -431,7 +391,7 @@ public class SshjSshClient implements SshClient {
@PreDestroy @PreDestroy
public void disconnect() { public void disconnect() {
try { try {
sshConnection.clear(); sshClientConnection.clear();
} catch (Exception e) { } catch (Exception e) {
Throwables.propagate(e); Throwables.propagate(e);
} }
@ -452,8 +412,8 @@ public class SshjSshClient implements SshClient {
@Override @Override
public Session create() throws Exception { public Session create() throws Exception {
checkConnected(); checkConnected();
session = ssh.startSession(); session = sshClientConnection.ssh.startSession();
session.allocatePTY("vt100", 80, 24, 0, 0, Collections.<PTYMode, Integer>emptyMap()); session.allocatePTY("vt100", 80, 24, 0, 0, Collections.<PTYMode, Integer> emptyMap());
return session; return session;
} }
@ -485,7 +445,7 @@ public class SshjSshClient implements SshClient {
session = acquire(execConnection()); session = acquire(execConnection());
Command output = session.exec(checkNotNull(command, "command")); Command output = session.exec(checkNotNull(command, "command"));
String outputString = IOUtils.readFully(output.getInputStream()).toString(); String outputString = IOUtils.readFully(output.getInputStream()).toString();
output.join(timeoutMillis, TimeUnit.SECONDS); output.join(sshClientConnection.getSessionTimeout(), TimeUnit.MILLISECONDS);
int errorStatus = output.getExitStatus(); int errorStatus = output.getExitStatus();
String errorString = IOUtils.readFully(output.getErrorStream()).toString(); String errorString = IOUtils.readFully(output.getErrorStream()).toString();
return new ExecResponse(outputString, errorString, errorStatus); return new ExecResponse(outputString, errorString, errorStatus);
@ -509,17 +469,21 @@ public class SshjSshClient implements SshClient {
return new Connection<Session>() { return new Connection<Session>() {
private Session session = null; private Session session = null;
private SSHClient sshClientConnection;
@Override @Override
public void clear() throws TransportException, ConnectionException { public void clear() throws TransportException, ConnectionException {
if (session != null) if (session != null)
session.close(); session.close();
if (sshClientConnection != null)
Closeables.closeQuietly(sshClientConnection);
} }
@Override @Override
public Session create() throws Exception { public Session create() throws Exception {
checkConnected(); this.sshClientConnection = acquire(SSHClientConnection.builder().fromSSHClientConnection(
session = ssh.startSession(); SshjSshClient.this.sshClientConnection).sessionTimeout(0).build());
session = sshClientConnection.startSession();
return session; return session;
} }
@ -587,7 +551,7 @@ public class SshjSshClient implements SshClient {
@Override @Override
public String getUsername() { public String getUsername() {
return this.username; return this.user;
} }
} }

View File

@ -65,9 +65,7 @@ public class SshjSshClientModule extends AbstractModule {
@Override @Override
public SshClient create(IPSocket socket, LoginCredentials credentials) { public SshClient create(IPSocket socket, LoginCredentials credentials) {
SshClient client = new SshjSshClient(backoffLimitedRetryHandler, socket, timeout, credentials.getUser(), SshClient client = new SshjSshClient(backoffLimitedRetryHandler, socket, credentials, timeout);
(credentials.getPrivateKey() == null) ? credentials.getPassword() : null,
credentials.getPrivateKey() != null ? credentials.getPrivateKey().getBytes() : null);
injector.injectMembers(client);// add logger injector.injectMembers(client);// add logger
return client; return client;
} }

View File

@ -170,8 +170,10 @@ public class SshjSshClientLiveTest {
: sshHost); : sshHost);
} }
public void testExecChannelTakesStdinAndNoEchoOfCharsInOuput() throws IOException { public void testExecChannelTakesStdinAndNoEchoOfCharsInOuputAndOutlivesClient() throws IOException {
ExecChannel response = setupClient().execChannel("cat <<EOF"); SshClient client = setupClient();
ExecChannel response = client.execChannel("cat <<EOF");
client.disconnect();
assertEquals(response.getExitStatus().get(), null); assertEquals(response.getExitStatus().get(), null);
try { try {
PrintStream printStream = new PrintStream(response.getInput()); PrintStream printStream = new PrintStream(response.getInput());

View File

@ -172,8 +172,8 @@ public class SshjSshClientTest {
ssh.disconnect(); ssh.disconnect();
expectLastCall().andThrow(new ConnectionException("disconnected")); expectLastCall().andThrow(new ConnectionException("disconnected"));
replay(ssh); replay(ssh);
ssh1.ssh = ssh; ssh1.sshClientConnection.ssh = ssh;
ssh1.sshConnection.clear(); ssh1.sshClientConnection.clear();
verify(ssh); verify(ssh);
} }
@ -186,8 +186,7 @@ public class SshjSshClientTest {
} }
public void testRetriesLoggedAtInfoWithCount() throws Exception { public void testRetriesLoggedAtInfoWithCount() throws Exception {
@SuppressWarnings("unchecked") SSHClientConnection mockConnection = createMock(SSHClientConnection.class);
SshjSshClient.Connection<net.schmizz.sshj.SSHClient> mockConnection = createMock(SshjSshClient.Connection.class);
net.schmizz.sshj.SSHClient mockClient = createMock(net.schmizz.sshj.SSHClient.class); net.schmizz.sshj.SSHClient mockClient = createMock(net.schmizz.sshj.SSHClient.class);
mockConnection.clear(); expectLastCall(); mockConnection.clear(); expectLastCall();
@ -199,14 +198,14 @@ public class SshjSshClientTest {
replay(mockConnection); replay(mockConnection);
replay(mockClient); replay(mockClient);
ssh.sshConnection = mockConnection; ssh.sshClientConnection = mockConnection;
BufferLogger logcheck = new BufferLogger(ssh.getClass().getCanonicalName()); BufferLogger logcheck = new BufferLogger(ssh.getClass().getCanonicalName());
ssh.logger = logcheck; ssh.logger = logcheck;
logcheck.setLevel(Level.INFO); logcheck.setLevel(Level.INFO);
ssh.connect(); ssh.connect();
Assert.assertEquals(ssh.ssh, mockClient); Assert.assertEquals(ssh.sshClientConnection, mockConnection);
verify(mockConnection); verify(mockConnection);
verify(mockClient); verify(mockClient);
Record r = logcheck.assertLogContains("attempt 1 of 5"); Record r = logcheck.assertLogContains("attempt 1 of 5");