HADOOP-7121. Exceptions while serializing IPC call responses are not handled well. Contributed by Todd Lipcon.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1129982 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Todd Lipcon 2011-06-01 02:00:57 +00:00
parent fe45b6ed79
commit e05a6d1dce
3 changed files with 216 additions and 67 deletions

View File

@ -751,6 +751,9 @@ Release 0.22.0 - Unreleased
HADOOP-7276. Hadoop native builds fail on ARM due to -m32 (Trevor Robinson
via eli)
HADOOP-7121. Exceptions while serializing IPC call responses are not
handled well. (todd)
Release 0.21.1 - Unreleased
IMPROVEMENTS

View File

@ -913,9 +913,9 @@ public class Connection {
public UserGroupInformation attemptingUser = null; // user name before auth public UserGroupInformation attemptingUser = null; // user name before auth
// Fake 'call' for failed authorization response // Fake 'call' for failed authorization response
private static final int AUTHROIZATION_FAILED_CALLID = -1; private static final int AUTHORIZATION_FAILED_CALLID = -1;
private final Call authFailedCall = private final Call authFailedCall =
new Call(AUTHROIZATION_FAILED_CALLID, null, this); new Call(AUTHORIZATION_FAILED_CALLID, null, this);
private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream(); private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream();
// Fake 'call' for SASL context setup // Fake 'call' for SASL context setup
private static final int SASL_CALLID = -33; private static final int SASL_CALLID = -33;
@ -1355,9 +1355,22 @@ private void processData(byte[] buf) throws IOException, InterruptedException {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug(" got #" + id); LOG.debug(" got #" + id);
Writable param;
try {
param = ReflectionUtils.newInstance(paramClass, conf);//read param
param.readFields(dis);
} catch (Throwable t) {
LOG.warn("Unable to read call parameters for client " +
getHostAddress(), t);
final Call readParamsFailedCall = new Call(id, null, this);
ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream();
Writable param = ReflectionUtils.newInstance(paramClass, conf);//read param setupResponse(responseBuffer, readParamsFailedCall, Status.FATAL, null,
param.readFields(dis); t.getClass().getName(),
"IPC server unable to read call parameters: " + t.getMessage());
responder.doRespond(readParamsFailedCall);
return;
}
Call call = new Call(id, param, this); Call call = new Call(id, param, this);
callQueue.put(call); // queue the call; maybe blocked here callQueue.put(call); // queue the call; maybe blocked here
@ -1591,7 +1604,18 @@ private void setupResponse(ByteArrayOutputStream response,
out.writeInt(status.state); // write status out.writeInt(status.state); // write status
if (status == Status.SUCCESS) { if (status == Status.SUCCESS) {
rv.write(out); try {
rv.write(out);
} catch (Throwable t) {
LOG.warn("Error serializing call response for call " + call, t);
// Call back into this same function - this is OK since the
// buffer is reset at the top, and since status is changed
// to ERROR it cannot recurse infinitely.
setupResponse(response, call, Status.ERROR,
null, t.getClass().getName(),
StringUtils.stringifyException(t));
return;
}
} else { } else {
WritableUtils.writeString(out, errorClass); WritableUtils.writeString(out, errorClass);
WritableUtils.writeString(out, error); WritableUtils.writeString(out, error);

View File

@ -28,30 +28,37 @@
import java.util.Random; import java.util.Random;
import java.io.DataInput; import java.io.DataInput;
import java.io.File; import java.io.File;
import java.io.DataOutput;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketTimeoutException; import java.net.SocketTimeoutException;
import javax.net.SocketFactory; import javax.net.SocketFactory;
import junit.framework.TestCase; import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.junit.Assume; import org.junit.Assume;
/** Unit tests for IPC. */ /** Unit tests for IPC. */
public class TestIPC extends TestCase { public class TestIPC {
public static final Log LOG = public static final Log LOG =
LogFactory.getLog(TestIPC.class); LogFactory.getLog(TestIPC.class);
final private static Configuration conf = new Configuration(); final private static Configuration conf = new Configuration();
final static private int PING_INTERVAL = 1000; final static private int PING_INTERVAL = 1000;
final static private int MIN_SLEEP_TIME = 1000; final static private int MIN_SLEEP_TIME = 1000;
/**
* Flag used to turn off the fault injection behavior
* of the various writables.
**/
static boolean WRITABLE_FAULTS_ENABLED = true;
static { static {
Client.setPingInterval(conf, PING_INTERVAL); Client.setPingInterval(conf, PING_INTERVAL);
} }
public TestIPC(String name) { super(name); }
private static final Random RANDOM = new Random(); private static final Random RANDOM = new Random();
@ -62,11 +69,19 @@ public class TestIPC extends TestCase {
private static class TestServer extends Server { private static class TestServer extends Server {
private boolean sleep; private boolean sleep;
private Class<? extends Writable> responseClass;
public TestServer(int handlerCount, boolean sleep) public TestServer(int handlerCount, boolean sleep) throws IOException {
this(handlerCount, sleep, LongWritable.class, null);
}
public TestServer(int handlerCount, boolean sleep,
Class<? extends Writable> paramClass,
Class<? extends Writable> responseClass)
throws IOException { throws IOException {
super(ADDRESS, 0, LongWritable.class, handlerCount, conf); super(ADDRESS, 0, paramClass, handlerCount, conf);
this.sleep = sleep; this.sleep = sleep;
this.responseClass = responseClass;
} }
@Override @Override
@ -78,7 +93,15 @@ public Writable call(Class<?> protocol, Writable param, long receiveTime)
Thread.sleep(RANDOM.nextInt(PING_INTERVAL) + MIN_SLEEP_TIME); Thread.sleep(RANDOM.nextInt(PING_INTERVAL) + MIN_SLEEP_TIME);
} catch (InterruptedException e) {} } catch (InterruptedException e) {}
} }
return param; // echo param as result if (responseClass != null) {
try {
return responseClass.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
return param; // echo param as result
}
} }
} }
@ -148,6 +171,7 @@ public void run() {
} }
} }
@Test
public void testSerial() throws Exception { public void testSerial() throws Exception {
testSerial(3, false, 2, 5, 100); testSerial(3, false, 2, 5, 100);
testSerial(3, true, 2, 5, 10); testSerial(3, true, 2, 5, 10);
@ -180,6 +204,7 @@ public void testSerial(int handlerCount, boolean handlerSleep,
server.stop(); server.stop();
} }
@Test
public void testParallel() throws Exception { public void testParallel() throws Exception {
testParallel(10, false, 2, 4, 2, 4, 100); testParallel(10, false, 2, 4, 2, 4, 100);
} }
@ -222,6 +247,7 @@ public void testParallel(int handlerCount, boolean handlerSleep,
} }
} }
@Test
public void testStandAloneClient() throws Exception { public void testStandAloneClient() throws Exception {
testParallel(10, false, 2, 4, 2, 4, 100); testParallel(10, false, 2, 4, 2, 4, 100);
Client client = new Client(LongWritable.class, conf); Client client = new Client(LongWritable.class, conf);
@ -242,83 +268,179 @@ public void testStandAloneClient() throws Exception {
message.contains(causeText)); message.contains(causeText));
} }
} }
private static class LongErrorWritable extends LongWritable { static void maybeThrowIOE() throws IOException {
private final static String ERR_MSG = if (WRITABLE_FAULTS_ENABLED) {
"Come across an exception while reading"; throw new IOException("Injected fault");
LongErrorWritable() {}
LongErrorWritable(long longValue) {
super(longValue);
} }
}
static void maybeThrowRTE() {
if (WRITABLE_FAULTS_ENABLED) {
throw new RuntimeException("Injected fault");
}
}
@SuppressWarnings("unused")
private static class IOEOnReadWritable extends LongWritable {
public IOEOnReadWritable() {}
public void readFields(DataInput in) throws IOException { public void readFields(DataInput in) throws IOException {
super.readFields(in); super.readFields(in);
throw new IOException(ERR_MSG); maybeThrowIOE();
} }
} }
private static class LongRTEWritable extends LongWritable { @SuppressWarnings("unused")
private final static String ERR_MSG = private static class RTEOnReadWritable extends LongWritable {
"Come across an runtime exception while reading"; public RTEOnReadWritable() {}
LongRTEWritable() {}
LongRTEWritable(long longValue) {
super(longValue);
}
public void readFields(DataInput in) throws IOException { public void readFields(DataInput in) throws IOException {
super.readFields(in); super.readFields(in);
throw new RuntimeException(ERR_MSG); maybeThrowRTE();
}
}
@SuppressWarnings("unused")
private static class IOEOnWriteWritable extends LongWritable {
public IOEOnWriteWritable() {}
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
maybeThrowIOE();
} }
} }
public void testErrorClient() throws Exception { @SuppressWarnings("unused")
private static class RTEOnWriteWritable extends LongWritable {
public RTEOnWriteWritable() {}
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
maybeThrowRTE();
}
}
/**
* Generic test case for exceptions thrown at some point in the IPC
* process.
*
* @param clientParamClass - client writes this writable for parameter
* @param serverParamClass - server reads this writable for parameter
* @param serverResponseClass - server writes this writable for response
* @param clientResponseClass - client reads this writable for response
*/
private void doErrorTest(
Class<? extends LongWritable> clientParamClass,
Class<? extends LongWritable> serverParamClass,
Class<? extends LongWritable> serverResponseClass,
Class<? extends LongWritable> clientResponseClass) throws Exception {
// start server // start server
Server server = new TestServer(1, false); Server server = new TestServer(1, false,
serverParamClass, serverResponseClass);
InetSocketAddress addr = NetUtils.getConnectAddress(server); InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start(); server.start();
// start client // start client
Client client = new Client(LongErrorWritable.class, conf); WRITABLE_FAULTS_ENABLED = true;
Client client = new Client(clientResponseClass, conf);
try { try {
client.call(new LongErrorWritable(RANDOM.nextLong()), LongWritable param = clientParamClass.newInstance();
addr, null, null, 0, conf);
fail("Expected an exception to have been thrown"); try {
} catch (IOException e) { client.call(param, addr, null, null, 0, conf);
// check error fail("Expected an exception to have been thrown");
Throwable cause = e.getCause(); } catch (Throwable t) {
assertTrue(cause instanceof IOException); assertExceptionContains(t, "Injected fault");
assertEquals(LongErrorWritable.ERR_MSG, cause.getMessage()); }
// Doing a second call with faults disabled should return fine --
// i.e. the internal state of the client or server should not be broken
// by the failed call.
WRITABLE_FAULTS_ENABLED = false;
client.call(param, addr, null, null, 0, conf);
} finally {
server.stop();
} }
} }
@Test
public void testIOEOnClientWriteParam() throws Exception {
doErrorTest(IOEOnWriteWritable.class,
LongWritable.class,
LongWritable.class,
LongWritable.class);
}
public void testRuntimeExceptionWritable() throws Exception { @Test
// start server public void testRTEOnClientWriteParam() throws Exception {
Server server = new TestServer(1, false); doErrorTest(RTEOnWriteWritable.class,
InetSocketAddress addr = NetUtils.getConnectAddress(server); LongWritable.class,
server.start(); LongWritable.class,
LongWritable.class);
// start client
Client client = new Client(LongRTEWritable.class, conf);
try {
client.call(new LongRTEWritable(RANDOM.nextLong()),
addr, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
// check error
Throwable cause = e.getCause();
assertTrue(cause instanceof IOException);
// it's double-wrapped
Throwable cause2 = cause.getCause();
assertTrue(cause2 instanceof RuntimeException);
assertEquals(LongRTEWritable.ERR_MSG, cause2.getMessage());
}
} }
@Test
public void testIOEOnServerReadParam() throws Exception {
doErrorTest(LongWritable.class,
IOEOnReadWritable.class,
LongWritable.class,
LongWritable.class);
}
@Test
public void testRTEOnServerReadParam() throws Exception {
doErrorTest(LongWritable.class,
RTEOnReadWritable.class,
LongWritable.class,
LongWritable.class);
}
@Test
public void testIOEOnServerWriteResponse() throws Exception {
doErrorTest(LongWritable.class,
LongWritable.class,
IOEOnWriteWritable.class,
LongWritable.class);
}
@Test
public void testRTEOnServerWriteResponse() throws Exception {
doErrorTest(LongWritable.class,
LongWritable.class,
RTEOnWriteWritable.class,
LongWritable.class);
}
@Test
public void testIOEOnClientReadResponse() throws Exception {
doErrorTest(LongWritable.class,
LongWritable.class,
LongWritable.class,
IOEOnReadWritable.class);
}
@Test
public void testRTEOnClientReadResponse() throws Exception {
doErrorTest(LongWritable.class,
LongWritable.class,
LongWritable.class,
RTEOnReadWritable.class);
}
private static void assertExceptionContains(
Throwable t, String substring) {
String msg = StringUtils.stringifyException(t);
assertTrue("Exception should contain substring '" + substring + "':\n" +
msg, msg.contains(substring));
LOG.info("Got expected exception", t);
}
/** /**
* Test that, if the socket factory throws an IOE, it properly propagates * Test that, if the socket factory throws an IOE, it properly propagates
* to the client. * to the client.
@ -384,9 +506,9 @@ private long countOpenFileDescriptors() {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
//new TestIPC("test").testSerial(5, false, 2, 10, 1000); //new TestIPC().testSerial(5, false, 2, 10, 1000);
new TestIPC("test").testParallel(10, false, 2, 4, 2, 4, 1000); new TestIPC().testParallel(10, false, 2, 4, 2, 4, 1000);
} }