JschSshClinet bug fix for exec method

This commit is contained in:
Dmitri Babaev 2011-06-01 22:26:08 +04:00
parent ebf3527595
commit 6dc6d3581f
1 changed files with 324 additions and 307 deletions

View File

@ -1,307 +1,324 @@
/** /**
* *
* Copyright (C) 2011 Cloud Conscious, LLC. <info@cloudconscious.com> * Copyright (C) 2011 Cloud Conscious, LLC. <info@cloudconscious.com>
* *
* ==================================================================== * ====================================================================
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
* ==================================================================== * ====================================================================
*/ */
package org.jclouds.ssh.jsch; package org.jclouds.ssh.jsch;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.instanceOf; import static com.google.common.base.Predicates.instanceOf;
import static com.google.common.base.Predicates.or; import static com.google.common.base.Predicates.or;
import static com.google.common.base.Throwables.getCausalChain; import static com.google.common.base.Throwables.getCausalChain;
import static com.google.common.base.Throwables.getRootCause; import static com.google.common.base.Throwables.getRootCause;
import static com.google.common.collect.Iterables.any; import static com.google.common.collect.Iterables.any;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.ConnectException; import java.net.ConnectException;
import java.util.Arrays; import java.util.Arrays;
import javax.annotation.PostConstruct; import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy; import javax.annotation.PreDestroy;
import javax.annotation.Resource; import javax.annotation.Resource;
import javax.inject.Named; import javax.inject.Named;
import org.apache.commons.io.input.ProxyInputStream; import org.apache.commons.io.input.ProxyInputStream;
import org.apache.commons.io.output.ByteArrayOutputStream; import org.apache.commons.io.output.ByteArrayOutputStream;
import org.jclouds.compute.domain.ExecResponse; import org.jclouds.compute.domain.ExecResponse;
import org.jclouds.http.handlers.BackoffLimitedRetryHandler; import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
import org.jclouds.io.Payload; import org.jclouds.io.Payload;
import org.jclouds.io.Payloads; import org.jclouds.io.Payloads;
import org.jclouds.logging.Logger; import org.jclouds.logging.Logger;
import org.jclouds.net.IPSocket; import org.jclouds.net.IPSocket;
import org.jclouds.ssh.SshClient; import org.jclouds.ssh.SshClient;
import org.jclouds.ssh.SshException; import org.jclouds.ssh.SshException;
import org.jclouds.util.Strings2; import org.jclouds.util.Strings2;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Predicate; import com.google.common.base.Predicate;
import com.google.common.base.Splitter; import com.google.common.base.Splitter;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import com.google.common.io.Closeables; import com.google.common.io.Closeables;
import com.google.inject.Inject; import com.google.inject.Inject;
import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelSftp; import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException; import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session; import com.jcraft.jsch.Session;
import com.jcraft.jsch.SftpException; import com.jcraft.jsch.SftpException;
/** /**
* This class needs refactoring. It is not thread safe. * This class needs refactoring. It is not thread safe.
* *
* @author Adrian Cole * @author Adrian Cole
*/ */
public class JschSshClient implements SshClient { public class JschSshClient implements SshClient {
private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream { private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
private final ChannelSftp sftp; private final ChannelSftp sftp;
private CloseFtpChannelOnCloseInputStream(InputStream proxy, ChannelSftp sftp) { private CloseFtpChannelOnCloseInputStream(InputStream proxy, ChannelSftp sftp) {
super(proxy); super(proxy);
this.sftp = sftp; this.sftp = sftp;
} }
@Override @Override
public void close() throws IOException { public void close() throws IOException {
super.close(); super.close();
if (sftp != null) if (sftp != null)
sftp.disconnect(); sftp.disconnect();
} }
} }
private final String host; private final String host;
private final int port; private final int port;
private final String username; private final String username;
private final String password; private final String password;
@Inject(optional = true) @Inject(optional = true)
@Named("jclouds.ssh.max_retries") @Named("jclouds.ssh.max_retries")
@VisibleForTesting @VisibleForTesting
int sshRetries = 5; int sshRetries = 5;
@Inject(optional = true) @Inject(optional = true)
@Named("jclouds.ssh.retryable_messages") @Named("jclouds.ssh.retryable_messages")
@VisibleForTesting @VisibleForTesting
String retryableMessages = "invalid data,End of IO Stream Read,Connection reset,connection is closed by foreign host,socket is not established"; String retryableMessages = "invalid data,End of IO Stream Read,Connection reset,connection is closed by foreign host,socket is not established";
@Inject(optional = true) @Inject(optional = true)
@Named("jclouds.ssh.retry_predicate") @Named("jclouds.ssh.retry_predicate")
private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectException.class), instanceOf(IOException.class)); private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectException.class), instanceOf(IOException.class));
@Resource @Resource
@Named("jclouds.ssh") @Named("jclouds.ssh")
protected Logger logger = Logger.NULL; protected Logger logger = Logger.NULL;
private Session session; private Session session;
private final byte[] privateKey; private final byte[] privateKey;
final byte[] emptyPassPhrase = new byte[0]; final byte[] emptyPassPhrase = new byte[0];
private final int timeout; private final int timeout;
private final BackoffLimitedRetryHandler backoffLimitedRetryHandler; private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;
public JschSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout, public JschSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
String username, String password, byte[] privateKey) { String username, String password, byte[] privateKey) {
this.host = checkNotNull(socket, "socket").getAddress(); this.host = checkNotNull(socket, "socket").getAddress();
checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort()); checkArgument(socket.getPort() > 0, "ssh port must be greater then zero" + socket.getPort());
checkArgument(password != null || privateKey != null, "you must specify a password or a key"); checkArgument(password != null || privateKey != null, "you must specify a password or a key");
this.port = socket.getPort(); this.port = socket.getPort();
this.username = checkNotNull(username, "username"); this.username = checkNotNull(username, "username");
this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler"); this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
this.timeout = timeout; this.timeout = timeout;
this.password = password; this.password = password;
this.privateKey = privateKey; this.privateKey = privateKey;
} }
public Payload get(String path) { public Payload get(String path) {
checkNotNull(path, "path"); checkNotNull(path, "path");
ChannelSftp sftp = getSftp(); ChannelSftp sftp = getSftp();
try { try {
return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.get(path), sftp)); return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.get(path), sftp));
} catch (SftpException e) { } catch (SftpException e) {
throw new SshException(String.format("%s@%s:%d: Error getting path: %s", username, host, port, path), e); throw new SshException(String.format("%s@%s:%d: Error getting path: %s", username, host, port, path), e);
} }
} }
@Override @Override
public void put(String path, Payload contents) { public void put(String path, Payload contents) {
checkNotNull(path, "path"); checkNotNull(path, "path");
checkNotNull(contents, "contents"); checkNotNull(contents, "contents");
ChannelSftp sftp = getSftp(); ChannelSftp sftp = getSftp();
try { try {
sftp.put(contents.getInput(), path); sftp.put(contents.getInput(), path);
} catch (SftpException e) { } catch (SftpException e) {
throw new SshException(String.format("%s@%s:%d: Error putting path: %s", username, host, port, path), e); throw new SshException(String.format("%s@%s:%d: Error putting path: %s", username, host, port, path), e);
} finally { } finally {
Closeables.closeQuietly(contents); Closeables.closeQuietly(contents);
} }
} }
@Override @Override
public void put(String path, String contents) { public void put(String path, String contents) {
put(path, Payloads.newStringPayload(checkNotNull(contents, "contents"))); put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
} }
private ChannelSftp getSftp() { private ChannelSftp getSftp() {
checkConnected(); checkConnected();
logger.debug("%s@%s:%d: Opening sftp Channel.", username, host, port); logger.debug("%s@%s:%d: Opening sftp Channel.", username, host, port);
ChannelSftp sftp = null; ChannelSftp sftp = null;
try { try {
sftp = (ChannelSftp) session.openChannel("sftp"); sftp = (ChannelSftp) session.openChannel("sftp");
sftp.connect(); sftp.connect();
} catch (JSchException e) { } catch (JSchException e) {
throw new SshException(String.format("%s@%s:%d: Error connecting to sftp.", username, host, port), e); throw new SshException(String.format("%s@%s:%d: Error connecting to sftp.", username, host, port), e);
} }
return sftp; return sftp;
} }
private void checkConnected() { private void checkConnected() {
checkState(session != null && session.isConnected(), String.format("%s@%s:%d: SFTP not connected!", username, checkState(session != null && session.isConnected(), String.format("%s@%s:%d: SFTP not connected!", username,
host, port)); host, port));
} }
@PostConstruct @PostConstruct
public void connect() { public void connect() {
disconnect(); disconnect();
Exception e = null; Exception e = null;
RETRY_LOOP: for (int i = 0; i < sshRetries; i++) { RETRY_LOOP: for (int i = 0; i < sshRetries; i++) {
try { try {
newSession(); newSession();
e = null; e = null;
break RETRY_LOOP; break RETRY_LOOP;
} catch (Exception from) { } catch (Exception from) {
e = from; e = from;
disconnect(); disconnect();
if (i == sshRetries) if (i == sshRetries)
throw propagate(from); throw propagate(from);
if (shouldRetry(from)) { if (shouldRetry(from)) {
backoffForAttempt(i + 1, String.format("%s@%s:%d: connection error: %s", username, host, port, from backoffForAttempt(i + 1, String.format("%s@%s:%d: connection error: %s", username, host, port, from
.getMessage())); .getMessage()));
continue; continue;
} }
throw propagate(from); throw propagate(from);
} }
} }
if (e != null) if (e != null)
throw propagate(e); throw propagate(e);
} }
@VisibleForTesting @VisibleForTesting
boolean shouldRetry(Exception from) { boolean shouldRetry(Exception from) {
final String rootMessage = getRootCause(from).getMessage(); final String rootMessage = getRootCause(from).getMessage();
return any(getCausalChain(from), retryPredicate) return any(getCausalChain(from), retryPredicate)
|| Iterables.any(Splitter.on(",").split(retryableMessages), new Predicate<String>() { || Iterables.any(Splitter.on(",").split(retryableMessages), new Predicate<String>() {
@Override @Override
public boolean apply(String input) { public boolean apply(String input) {
return rootMessage.indexOf(input) != -1; return rootMessage.indexOf(input) != -1;
} }
}); });
} }
private void backoffForAttempt(int retryAttempt, String message) { private void backoffForAttempt(int retryAttempt, String message) {
backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message); backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
} }
private void newSession() throws JSchException { private void newSession() throws JSchException {
JSch jsch = new JSch(); JSch jsch = new JSch();
session = null; session = null;
try { try {
session = jsch.getSession(username, host, port); session = jsch.getSession(username, host, port);
if (timeout != 0) if (timeout != 0)
session.setTimeout(timeout); session.setTimeout(timeout);
logger.debug("%s@%s:%d: Session created.", username, host, port); logger.debug("%s@%s:%d: Session created.", username, host, port);
if (password != null) { if (password != null) {
session.setPassword(password); session.setPassword(password);
} else { } else {
// jsch wipes out your private key // jsch wipes out your private key
jsch.addIdentity(username, Arrays.copyOf(privateKey, privateKey.length), null, emptyPassPhrase); jsch.addIdentity(username, Arrays.copyOf(privateKey, privateKey.length), null, emptyPassPhrase);
} }
} catch (JSchException e) { } catch (JSchException e) {
throw new SshException(String.format("%s@%s:%d: Error creating session.", username, host, port), e); throw new SshException(String.format("%s@%s:%d: Error creating session.", username, host, port), e);
} }
java.util.Properties config = new java.util.Properties(); java.util.Properties config = new java.util.Properties();
config.put("StrictHostKeyChecking", "no"); config.put("StrictHostKeyChecking", "no");
session.setConfig(config); session.setConfig(config);
session.connect(); session.connect();
logger.debug("%s@%s:%d: Session connected.", username, host, port); logger.debug("%s@%s:%d: Session connected.", username, host, port);
} }
private SshException propagate(Exception e) { private SshException propagate(Exception e) {
throw new SshException(String.format("%s@%s:%d: Error connecting to session.", username, host, port), e); throw new SshException(String.format("%s@%s:%d: Error connecting to session.", username, host, port), e);
} }
@PreDestroy @PreDestroy
public void disconnect() { public void disconnect() {
if (session != null && session.isConnected()) { if (session != null && session.isConnected()) {
session.disconnect(); session.disconnect();
session = null; session = null;
} }
} }
public ExecResponse exec(String command) { public ExecResponse exec(String command) {
checkConnected(); checkConnected();
ChannelExec executor = null;
try { ChannelExec executor = null;
try { ByteArrayOutputStream error = null;
executor = (ChannelExec) session.openChannel("exec");
executor.setPty(true); int j = 0;
} catch (JSchException e) { do {
throw new SshException(String.format("%s@%s:%d: Error connecting to exec.", username, host, port), e); try {
} executor = (ChannelExec) session.openChannel("exec");
executor.setCommand(command); } catch (JSchException e) {
ByteArrayOutputStream error = new ByteArrayOutputStream(); // unrecoverable fail because ssh session closed
executor.setErrStream(error); throw new SshException(String.format("%s@%s:%d: Error connecting to exec.", username, host, port), e);
try { }
executor.connect();
String outputString = Strings2.toStringAndClose(executor.getInputStream()); error = new ByteArrayOutputStream();
String errorString = error.toString(); executor.setPty(true);
int errorStatus = executor.getExitStatus(); executor.setCommand(command);
int i = 0; executor.setErrStream(error);
while ((errorStatus = executor.getExitStatus()) == -1 && i < this.sshRetries)
backoffForAttempt(++i, String.format("%s@%s:%d: bad status: -1", username, host, port)); try {
if (errorStatus == -1) executor.connect();
throw new SshException(String.format("%s@%s:%d: received exit status %d executing %s", username, host, } catch (JSchException e) {
port, executor.getExitStatus(), command)); executor.disconnect();
return new ExecResponse(outputString, errorString, errorStatus); backoffForAttempt(++j, String.format("%s@%s:%d: Failed to connect ChannelExec", username, host, port));
} catch (Exception e) { }
throw new SshException(String } while (j < this.sshRetries && !executor.isConnected());
.format("%s@%s:%d: Error executing command: %s", username, host, port, command), e);
} if (!executor.isConnected())
} finally { throw new SshException(String.format("%s@%s:%d: Failed to connect ChannelExec executing %s",
if (executor != null) username, host, port, command));
executor.disconnect();
} try {
} String outputString = Strings2.toStringAndClose(executor.getInputStream());
String errorString = error.toString();
@Override int errorStatus = executor.getExitStatus();
public String getHostAddress() { int i = 0;
return this.host; while ((errorStatus = executor.getExitStatus()) == -1 && i < this.sshRetries)
} backoffForAttempt(++i, String.format("%s@%s:%d: bad status: -1", username, host, port));
if (errorStatus == -1)
@Override throw new SshException(String.format("%s@%s:%d: received exit status %d executing %s", username, host,
public String getUsername() { port, executor.getExitStatus(), command));
return this.username; return new ExecResponse(outputString, errorString, errorStatus);
} } catch (Exception e) {
throw new SshException(String
} .format("%s@%s:%d: Error executing command: %s", username, host, port, command), e);
}
finally {
executor.disconnect();
}
}
@Override
public String getHostAddress() {
return this.host;
}
@Override
public String getUsername() {
return this.username;
}
}