diff --git a/CHANGES.txt b/CHANGES.txt index 192668c8a37..f30f860abfd 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -751,6 +751,9 @@ Release 0.22.0 - Unreleased HADOOP-7276. Hadoop native builds fail on ARM due to -m32 (Trevor Robinson via eli) + HADOOP-7121. Exceptions while serializing IPC call responses are not + handled well. (todd) + Release 0.21.1 - Unreleased IMPROVEMENTS diff --git a/src/java/org/apache/hadoop/ipc/Server.java b/src/java/org/apache/hadoop/ipc/Server.java index dd7313c134b..f1ba0aa5862 100644 --- a/src/java/org/apache/hadoop/ipc/Server.java +++ b/src/java/org/apache/hadoop/ipc/Server.java @@ -913,9 +913,9 @@ public abstract class Server { public UserGroupInformation attemptingUser = null; // user name before auth // Fake 'call' for failed authorization response - private static final int AUTHROIZATION_FAILED_CALLID = -1; + private static final int AUTHORIZATION_FAILED_CALLID = -1; private final Call authFailedCall = - new Call(AUTHROIZATION_FAILED_CALLID, null, this); + new Call(AUTHORIZATION_FAILED_CALLID, null, this); private ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream(); // Fake 'call' for SASL context setup private static final int SASL_CALLID = -33; @@ -1355,9 +1355,22 @@ public abstract class Server { if (LOG.isDebugEnabled()) LOG.debug(" got #" + id); + Writable param; + try { + param = ReflectionUtils.newInstance(paramClass, conf);//read param + param.readFields(dis); + } catch (Throwable t) { + LOG.warn("Unable to read call parameters for client " + + getHostAddress(), t); + final Call readParamsFailedCall = new Call(id, null, this); + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); - Writable param = ReflectionUtils.newInstance(paramClass, conf);//read param - param.readFields(dis); + setupResponse(responseBuffer, readParamsFailedCall, Status.FATAL, null, + t.getClass().getName(), + "IPC server unable to read call parameters: " + t.getMessage()); + responder.doRespond(readParamsFailedCall); + return; + } Call call = new Call(id, param, this); callQueue.put(call); // queue the call; maybe blocked here @@ -1591,7 +1604,18 @@ public abstract class Server { out.writeInt(status.state); // write status if (status == Status.SUCCESS) { - rv.write(out); + try { + rv.write(out); + } catch (Throwable t) { + LOG.warn("Error serializing call response for call " + call, t); + // Call back to same function - this is OK since the + // buffer is reset at the top, and since status is changed + // to ERROR it won't infinite loop. + setupResponse(response, call, Status.ERROR, + null, t.getClass().getName(), + StringUtils.stringifyException(t)); + return; + } } else { WritableUtils.writeString(out, errorClass); WritableUtils.writeString(out, error); diff --git a/src/test/core/org/apache/hadoop/ipc/TestIPC.java b/src/test/core/org/apache/hadoop/ipc/TestIPC.java index 649aa5922a1..83d9d1b5f74 100644 --- a/src/test/core/org/apache/hadoop/ipc/TestIPC.java +++ b/src/test/core/org/apache/hadoop/ipc/TestIPC.java @@ -28,30 +28,37 @@ import org.apache.hadoop.net.NetUtils; import java.util.Random; import java.io.DataInput; import java.io.File; +import java.io.DataOutput; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketTimeoutException; import javax.net.SocketFactory; -import junit.framework.TestCase; +import org.junit.Test; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; import org.apache.hadoop.conf.Configuration; import org.junit.Assume; /** Unit tests for IPC. */ -public class TestIPC extends TestCase { +public class TestIPC { public static final Log LOG = LogFactory.getLog(TestIPC.class); final private static Configuration conf = new Configuration(); final static private int PING_INTERVAL = 1000; final static private int MIN_SLEEP_TIME = 1000; + + /** + * Flag used to turn off the fault injection behavior + * of the various writables. + **/ + static boolean WRITABLE_FAULTS_ENABLED = true; static { Client.setPingInterval(conf, PING_INTERVAL); } - public TestIPC(String name) { super(name); } private static final Random RANDOM = new Random(); @@ -62,11 +69,19 @@ public class TestIPC extends TestCase { private static class TestServer extends Server { private boolean sleep; + private Class responseClass; - public TestServer(int handlerCount, boolean sleep) + public TestServer(int handlerCount, boolean sleep) throws IOException { + this(handlerCount, sleep, LongWritable.class, null); + } + + public TestServer(int handlerCount, boolean sleep, + Class paramClass, + Class responseClass) throws IOException { - super(ADDRESS, 0, LongWritable.class, handlerCount, conf); + super(ADDRESS, 0, paramClass, handlerCount, conf); this.sleep = sleep; + this.responseClass = responseClass; } @Override @@ -78,7 +93,15 @@ public class TestIPC extends TestCase { Thread.sleep(RANDOM.nextInt(PING_INTERVAL) + MIN_SLEEP_TIME); } catch (InterruptedException e) {} } - return param; // echo param as result + if (responseClass != null) { + try { + return responseClass.newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + return param; // echo param as result + } } } @@ -148,6 +171,7 @@ public class TestIPC extends TestCase { } } + @Test public void testSerial() throws Exception { testSerial(3, false, 2, 5, 100); testSerial(3, true, 2, 5, 10); @@ -180,6 +204,7 @@ public class TestIPC extends TestCase { server.stop(); } + @Test public void testParallel() throws Exception { testParallel(10, false, 2, 4, 2, 4, 100); } @@ -222,6 +247,7 @@ public class TestIPC extends TestCase { } } + @Test public void testStandAloneClient() throws Exception { testParallel(10, false, 2, 4, 2, 4, 100); Client client = new Client(LongWritable.class, conf); @@ -242,83 +268,179 @@ public class TestIPC extends TestCase { message.contains(causeText)); } } - - private static class LongErrorWritable extends LongWritable { - private final static String ERR_MSG = - "Come across an exception while reading"; - - LongErrorWritable() {} - - LongErrorWritable(long longValue) { - super(longValue); + + static void maybeThrowIOE() throws IOException { + if (WRITABLE_FAULTS_ENABLED) { + throw new IOException("Injected fault"); } - + } + + static void maybeThrowRTE() { + if (WRITABLE_FAULTS_ENABLED) { + throw new RuntimeException("Injected fault"); + } + } + + @SuppressWarnings("unused") + private static class IOEOnReadWritable extends LongWritable { + public IOEOnReadWritable() {} + public void readFields(DataInput in) throws IOException { super.readFields(in); - throw new IOException(ERR_MSG); + maybeThrowIOE(); } } - private static class LongRTEWritable extends LongWritable { - private final static String ERR_MSG = - "Come across an runtime exception while reading"; - - LongRTEWritable() {} - - LongRTEWritable(long longValue) { - super(longValue); - } + @SuppressWarnings("unused") + private static class RTEOnReadWritable extends LongWritable { + public RTEOnReadWritable() {} public void readFields(DataInput in) throws IOException { super.readFields(in); - throw new RuntimeException(ERR_MSG); + maybeThrowRTE(); + } + } + + @SuppressWarnings("unused") + private static class IOEOnWriteWritable extends LongWritable { + public IOEOnWriteWritable() {} + + @Override + public void write(DataOutput out) throws IOException { + super.write(out); + maybeThrowIOE(); } } - public void testErrorClient() throws Exception { + @SuppressWarnings("unused") + private static class RTEOnWriteWritable extends LongWritable { + public RTEOnWriteWritable() {} + + @Override + public void write(DataOutput out) throws IOException { + super.write(out); + maybeThrowRTE(); + } + } + + /** + * Generic test case for exceptions thrown at some point in the IPC + * process. + * + * @param clientParamClass - client writes this writable for parameter + * @param serverParamClass - server reads this writable for parameter + * @param serverResponseClass - server writes this writable for response + * @param clientResponseClass - client reads this writable for response + */ + private void doErrorTest( + Class clientParamClass, + Class serverParamClass, + Class serverResponseClass, + Class clientResponseClass) throws Exception { + // start server - Server server = new TestServer(1, false); + Server server = new TestServer(1, false, + serverParamClass, serverResponseClass); InetSocketAddress addr = NetUtils.getConnectAddress(server); server.start(); // start client - Client client = new Client(LongErrorWritable.class, conf); + WRITABLE_FAULTS_ENABLED = true; + Client client = new Client(clientResponseClass, conf); try { - client.call(new LongErrorWritable(RANDOM.nextLong()), - addr, null, null, 0, conf); - fail("Expected an exception to have been thrown"); - } catch (IOException e) { - // check error - Throwable cause = e.getCause(); - assertTrue(cause instanceof IOException); - assertEquals(LongErrorWritable.ERR_MSG, cause.getMessage()); + LongWritable param = clientParamClass.newInstance(); + + try { + client.call(param, addr, null, null, 0, conf); + fail("Expected an exception to have been thrown"); + } catch (Throwable t) { + assertExceptionContains(t, "Injected fault"); + } + + // Doing a second call with faults disabled should return fine -- + // ie the internal state of the client or server should not be broken + // by the failed call + WRITABLE_FAULTS_ENABLED = false; + client.call(param, addr, null, null, 0, conf); + + } finally { + server.stop(); } } + + @Test + public void testIOEOnClientWriteParam() throws Exception { + doErrorTest(IOEOnWriteWritable.class, + LongWritable.class, + LongWritable.class, + LongWritable.class); + } - public void testRuntimeExceptionWritable() throws Exception { - // start server - Server server = new TestServer(1, false); - InetSocketAddress addr = NetUtils.getConnectAddress(server); - server.start(); - - // start client - Client client = new Client(LongRTEWritable.class, conf); - try { - client.call(new LongRTEWritable(RANDOM.nextLong()), - addr, null, null, 0, conf); - fail("Expected an exception to have been thrown"); - } catch (IOException e) { - // check error - Throwable cause = e.getCause(); - assertTrue(cause instanceof IOException); - // it's double-wrapped - Throwable cause2 = cause.getCause(); - assertTrue(cause2 instanceof RuntimeException); - - assertEquals(LongRTEWritable.ERR_MSG, cause2.getMessage()); - } + @Test + public void testRTEOnClientWriteParam() throws Exception { + doErrorTest(RTEOnWriteWritable.class, + LongWritable.class, + LongWritable.class, + LongWritable.class); } + @Test + public void testIOEOnServerReadParam() throws Exception { + doErrorTest(LongWritable.class, + IOEOnReadWritable.class, + LongWritable.class, + LongWritable.class); + } + + @Test + public void testRTEOnServerReadParam() throws Exception { + doErrorTest(LongWritable.class, + RTEOnReadWritable.class, + LongWritable.class, + LongWritable.class); + } + + + @Test + public void testIOEOnServerWriteResponse() throws Exception { + doErrorTest(LongWritable.class, + LongWritable.class, + IOEOnWriteWritable.class, + LongWritable.class); + } + + @Test + public void testRTEOnServerWriteResponse() throws Exception { + doErrorTest(LongWritable.class, + LongWritable.class, + RTEOnWriteWritable.class, + LongWritable.class); + } + + @Test + public void testIOEOnClientReadResponse() throws Exception { + doErrorTest(LongWritable.class, + LongWritable.class, + LongWritable.class, + IOEOnReadWritable.class); + } + + @Test + public void testRTEOnClientReadResponse() throws Exception { + doErrorTest(LongWritable.class, + LongWritable.class, + LongWritable.class, + RTEOnReadWritable.class); + } + + private static void assertExceptionContains( + Throwable t, String substring) { + String msg = StringUtils.stringifyException(t); + assertTrue("Exception should contain substring '" + substring + "':\n" + + msg, msg.contains(substring)); + LOG.info("Got expected exception", t); + } + /** * Test that, if the socket factory throws an IOE, it properly propagates * to the client. @@ -384,9 +506,9 @@ public class TestIPC extends TestCase { public static void main(String[] args) throws Exception { - //new TestIPC("test").testSerial(5, false, 2, 10, 1000); + //new TestIPC().testSerial(5, false, 2, 10, 1000); - new TestIPC("test").testParallel(10, false, 2, 4, 2, 4, 1000); + new TestIPC().testParallel(10, false, 2, 4, 2, 4, 1000); }