HADOOP-8350. Improve NetUtils.getInputStream to return a stream which has a tunable timeout. Contributed by Todd Lipcon.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1333649 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Todd Lipcon 2012-05-03 21:57:10 +00:00
parent 25882b199b
commit 03181022ab
7 changed files with 187 additions and 91 deletions

View File

@ -287,6 +287,9 @@ Release 2.0.0 - UNRELEASED
HADOOP-8347. Hadoop Common logs misspell 'successful'. HADOOP-8347. Hadoop Common logs misspell 'successful'.
(Philip Zeyliger via eli) (Philip Zeyliger via eli)
HADOOP-8350. Improve NetUtils.getInputStream to return a stream which has
a tunable timeout. (todd)
OPTIMIZATIONS OPTIMIZATIONS
BUG FIXES BUG FIXES

View File

@ -375,53 +375,44 @@ public class NetUtils {
} }
/** /**
* Same as getInputStream(socket, socket.getSoTimeout()).<br><br> * Same as <code>getInputStream(socket, socket.getSoTimeout()).</code>
* * <br><br>
* From documentation for {@link #getInputStream(Socket, long)}:<br>
* Returns InputStream for the socket. If the socket has an associated
* SocketChannel then it returns a
* {@link SocketInputStream} with the given timeout. If the socket does not
* have a channel, {@link Socket#getInputStream()} is returned. In the later
* case, the timeout argument is ignored and the timeout set with
* {@link Socket#setSoTimeout(int)} applies for reads.<br><br>
*
* Any socket created using socket factories returned by {@link NetUtils},
* must use this interface instead of {@link Socket#getInputStream()}.
* *
* @see #getInputStream(Socket, long) * @see #getInputStream(Socket, long)
*
* @param socket
* @return InputStream for reading from the socket.
* @throws IOException
*/ */
public static InputStream getInputStream(Socket socket) public static SocketInputWrapper getInputStream(Socket socket)
throws IOException { throws IOException {
return getInputStream(socket, socket.getSoTimeout()); return getInputStream(socket, socket.getSoTimeout());
} }
/** /**
* Returns InputStream for the socket. If the socket has an associated * Return a {@link SocketInputWrapper} for the socket and set the given
* SocketChannel then it returns a * timeout. If the socket does not have an associated channel, then its socket
* {@link SocketInputStream} with the given timeout. If the socket does not * timeout will be set to the specified value. Otherwise, a
* have a channel, {@link Socket#getInputStream()} is returned. In the later * {@link SocketInputStream} will be created which reads with the configured
* case, the timeout argument is ignored and the timeout set with * timeout.
* {@link Socket#setSoTimeout(int)} applies for reads.<br><br>
* *
* Any socket created using socket factories returned by {@link NetUtils}, * Any socket created using socket factories returned by {@link #NetUtils},
* must use this interface instead of {@link Socket#getInputStream()}. * must use this interface instead of {@link Socket#getInputStream()}.
* *
* In general, this should be called only once on each socket: see the note
* in {@link SocketInputWrapper#setTimeout(long)} for more information.
*
* @see Socket#getChannel() * @see Socket#getChannel()
* *
* @param socket * @param socket
* @param timeout timeout in milliseconds. This may not always apply. zero * @param timeout timeout in milliseconds. zero for waiting as
* for waiting as long as necessary. * long as necessary.
* @return InputStream for reading from the socket. * @return SocketInputWrapper for reading from the socket.
* @throws IOException * @throws IOException
*/ */
public static InputStream getInputStream(Socket socket, long timeout) public static SocketInputWrapper getInputStream(Socket socket, long timeout)
throws IOException { throws IOException {
return (socket.getChannel() == null) ? InputStream stm = (socket.getChannel() == null) ?
socket.getInputStream() : new SocketInputStream(socket, timeout); socket.getInputStream() : new SocketInputStream(socket);
SocketInputWrapper w = new SocketInputWrapper(socket, stm);
w.setTimeout(timeout);
return w;
} }
/** /**

View File

@ -248,6 +248,10 @@ abstract class SocketIOWithTimeout {
} }
} }
public void setTimeout(long timeoutMs) {
this.timeout = timeoutMs;
}
private static String timeoutExceptionString(SelectableChannel channel, private static String timeoutExceptionString(SelectableChannel channel,
long timeout, int ops) { long timeout, int ops) {

View File

@ -28,9 +28,6 @@ import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectableChannel; import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
/** /**
* This implements an input stream that can have a timeout while reading. * This implements an input stream that can have a timeout while reading.
* This sets non-blocking flag on the socket channel. * This sets non-blocking flag on the socket channel.
@ -40,9 +37,7 @@ import org.apache.hadoop.classification.InterfaceStability;
* IllegalBlockingModeException. * IllegalBlockingModeException.
* Please use {@link SocketOutputStream} for writing. * Please use {@link SocketOutputStream} for writing.
*/ */
@InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) class SocketInputStream extends InputStream
@InterfaceStability.Unstable
public class SocketInputStream extends InputStream
implements ReadableByteChannel { implements ReadableByteChannel {
private Reader reader; private Reader reader;
@ -171,4 +166,8 @@ public class SocketInputStream extends InputStream
public void waitForReadable() throws IOException { public void waitForReadable() throws IOException {
reader.waitForIO(SelectionKey.OP_READ); reader.waitForIO(SelectionKey.OP_READ);
} }
public void setTimeout(long timeoutMs) {
reader.setTimeout(timeoutMs);
}
} }

View File

@ -25,11 +25,14 @@ import java.net.ConnectException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.NetworkInterface; import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.net.SocketException; import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URI; import java.net.URI;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.concurrent.TimeUnit;
import junit.framework.AssertionFailedError; import junit.framework.AssertionFailedError;
@ -37,7 +40,11 @@ import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.security.NetUtilsTestResolver; import org.apache.hadoop.security.NetUtilsTestResolver;
import org.apache.hadoop.test.MultithreadedTestUtil.TestContext;
import org.apache.hadoop.test.MultithreadedTestUtil.TestingThread;
import org.junit.Assume;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
@ -50,6 +57,13 @@ public class TestNetUtils {
private static final int LOCAL_PORT = 8080; private static final int LOCAL_PORT = 8080;
private static final String LOCAL_PORT_NAME = Integer.toString(LOCAL_PORT); private static final String LOCAL_PORT_NAME = Integer.toString(LOCAL_PORT);
/**
* Some slop around expected times when making sure timeouts behave
* as expected. We assume that they will be accurate to within
* this threshold.
*/
static final long TIME_FUDGE_MILLIS = 200;
/** /**
* Test that we can't accidentally connect back to the connecting socket due * Test that we can't accidentally connect back to the connecting socket due
* to a quirk in the TCP spec. * to a quirk in the TCP spec.
@ -81,6 +95,79 @@ public class TestNetUtils {
} }
} }
@Test
public void testSocketReadTimeoutWithChannel() throws Exception {
doSocketReadTimeoutTest(true);
}
@Test
public void testSocketReadTimeoutWithoutChannel() throws Exception {
doSocketReadTimeoutTest(false);
}
private void doSocketReadTimeoutTest(boolean withChannel)
throws IOException {
// Binding a ServerSocket is enough to accept connections.
// Rely on the backlog to accept for us.
ServerSocket ss = new ServerSocket(0);
Socket s;
if (withChannel) {
s = NetUtils.getDefaultSocketFactory(new Configuration())
.createSocket();
Assume.assumeNotNull(s.getChannel());
} else {
s = new Socket();
assertNull(s.getChannel());
}
SocketInputWrapper stm = null;
try {
NetUtils.connect(s, ss.getLocalSocketAddress(), 1000);
stm = NetUtils.getInputStream(s, 1000);
assertReadTimeout(stm, 1000);
// Change timeout, make sure it applies.
stm.setTimeout(1);
assertReadTimeout(stm, 1);
// If there is a channel, then setting the socket timeout
// should not matter. If there is not a channel, it will
// take effect.
s.setSoTimeout(1000);
if (withChannel) {
assertReadTimeout(stm, 1);
} else {
assertReadTimeout(stm, 1000);
}
} finally {
IOUtils.closeStream(stm);
IOUtils.closeSocket(s);
ss.close();
}
}
private void assertReadTimeout(SocketInputWrapper stm, int timeoutMillis)
throws IOException {
long st = System.nanoTime();
try {
stm.read();
fail("Didn't time out");
} catch (SocketTimeoutException ste) {
assertTimeSince(st, timeoutMillis);
}
}
private void assertTimeSince(long startNanos, int expectedMillis) {
long durationNano = System.nanoTime() - startNanos;
long millis = TimeUnit.MILLISECONDS.convert(
durationNano, TimeUnit.NANOSECONDS);
assertTrue("Expected " + expectedMillis + "ms, but took " + millis,
Math.abs(millis - expectedMillis) < TIME_FUDGE_MILLIS);
}
/** /**
* Test for { * Test for {
* @throws UnknownHostException @link NetUtils#getLocalInetAddress(String) * @throws UnknownHostException @link NetUtils#getLocalInetAddress(String)

View File

@ -19,6 +19,7 @@ package org.apache.hadoop.net;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.SocketTimeoutException; import java.net.SocketTimeoutException;
import java.nio.channels.Pipe; import java.nio.channels.Pipe;
@ -26,8 +27,13 @@ import java.util.Arrays;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.test.MultithreadedTestUtil;
import org.apache.hadoop.test.MultithreadedTestUtil.TestContext;
import org.apache.hadoop.test.MultithreadedTestUtil.TestingThread;
import junit.framework.TestCase; import org.junit.Test;
import static org.junit.Assert.*;
/** /**
* This tests timout out from SocketInputStream and * This tests timout out from SocketInputStream and
@ -36,14 +42,17 @@ import junit.framework.TestCase;
* Normal read and write using these streams are tested by pretty much * Normal read and write using these streams are tested by pretty much
* every DFS unit test. * every DFS unit test.
*/ */
public class TestSocketIOWithTimeout extends TestCase { public class TestSocketIOWithTimeout {
static Log LOG = LogFactory.getLog(TestSocketIOWithTimeout.class); static Log LOG = LogFactory.getLog(TestSocketIOWithTimeout.class);
private static int TIMEOUT = 1*1000; private static int TIMEOUT = 1*1000;
private static String TEST_STRING = "1234567890"; private static String TEST_STRING = "1234567890";
private void doIO(InputStream in, OutputStream out) throws IOException { private MultithreadedTestUtil.TestContext ctx = new TestContext();
private void doIO(InputStream in, OutputStream out,
int expectedTimeout) throws IOException {
/* Keep on writing or reading until we get SocketTimeoutException. /* Keep on writing or reading until we get SocketTimeoutException.
* It expects this exception to occur within 100 millis of TIMEOUT. * It expects this exception to occur within 100 millis of TIMEOUT.
*/ */
@ -61,34 +70,15 @@ public class TestSocketIOWithTimeout extends TestCase {
long diff = System.currentTimeMillis() - start; long diff = System.currentTimeMillis() - start;
LOG.info("Got SocketTimeoutException as expected after " + LOG.info("Got SocketTimeoutException as expected after " +
diff + " millis : " + e.getMessage()); diff + " millis : " + e.getMessage());
assertTrue(Math.abs(TIMEOUT - diff) <= 200); assertTrue(Math.abs(expectedTimeout - diff) <=
TestNetUtils.TIME_FUDGE_MILLIS);
break; break;
} }
} }
} }
/** @Test
* Just reads one byte from the input stream. public void testSocketIOWithTimeout() throws Exception {
*/
static class ReadRunnable implements Runnable {
private InputStream in;
public ReadRunnable(InputStream in) {
this.in = in;
}
public void run() {
try {
in.read();
} catch (IOException e) {
LOG.info("Got expection while reading as expected : " +
e.getMessage());
return;
}
assertTrue(false);
}
}
public void testSocketIOWithTimeout() throws IOException {
// first open pipe: // first open pipe:
Pipe pipe = Pipe.open(); Pipe pipe = Pipe.open();
@ -96,7 +86,7 @@ public class TestSocketIOWithTimeout extends TestCase {
Pipe.SinkChannel sink = pipe.sink(); Pipe.SinkChannel sink = pipe.sink();
try { try {
InputStream in = new SocketInputStream(source, TIMEOUT); final InputStream in = new SocketInputStream(source, TIMEOUT);
OutputStream out = new SocketOutputStream(sink, TIMEOUT); OutputStream out = new SocketOutputStream(sink, TIMEOUT);
byte[] writeBytes = TEST_STRING.getBytes(); byte[] writeBytes = TEST_STRING.getBytes();
@ -105,37 +95,62 @@ public class TestSocketIOWithTimeout extends TestCase {
out.write(writeBytes); out.write(writeBytes);
out.write(byteWithHighBit); out.write(byteWithHighBit);
doIO(null, out); doIO(null, out, TIMEOUT);
in.read(readBytes); in.read(readBytes);
assertTrue(Arrays.equals(writeBytes, readBytes)); assertTrue(Arrays.equals(writeBytes, readBytes));
assertEquals(byteWithHighBit & 0xff, in.read()); assertEquals(byteWithHighBit & 0xff, in.read());
doIO(in, null); doIO(in, null, TIMEOUT);
// Change timeout on the read side.
((SocketInputStream)in).setTimeout(TIMEOUT * 2);
doIO(in, null, TIMEOUT * 2);
/* /*
* Verify that it handles interrupted threads properly. * Verify that it handles interrupted threads properly.
* Use a large timeout and expect the thread to return quickly. * Use a large timeout and expect the thread to return quickly
* upon interruption.
*/ */
in = new SocketInputStream(source, 0); ((SocketInputStream)in).setTimeout(0);
Thread thread = new Thread(new ReadRunnable(in)); TestingThread thread = new TestingThread(ctx) {
thread.start(); @Override
public void doWork() throws Exception {
try { try {
Thread.sleep(1000); in.read();
} catch (InterruptedException ignored) {} fail("Did not fail with interrupt");
} catch (InterruptedIOException ste) {
LOG.info("Got expection while reading as expected : " +
ste.getMessage());
}
}
};
ctx.addThread(thread);
ctx.startThreads();
// If the thread is interrupted before it calls read()
// then it throws ClosedByInterruptException due to
// some Java quirk. Waiting for it to call read()
// gets it into select(), so we get the expected
// InterruptedIOException.
Thread.sleep(1000);
thread.interrupt(); thread.interrupt();
ctx.stop();
try {
thread.join();
} catch (InterruptedException e) {
throw new IOException("Unexpected InterruptedException : " + e);
}
//make sure the channels are still open //make sure the channels are still open
assertTrue(source.isOpen()); assertTrue(source.isOpen());
assertTrue(sink.isOpen()); assertTrue(sink.isOpen());
// Nevertheless, the output stream is closed, because
// a partial write may have succeeded (see comment in
// SocketOutputStream#write(byte[]), int, int)
try {
out.write(1);
fail("Did not throw");
} catch (IOException ioe) {
GenericTestUtils.assertExceptionContains(
"stream is closed", ioe);
}
out.close(); out.close();
assertFalse(sink.isOpen()); assertFalse(sink.isOpen());

View File

@ -46,7 +46,7 @@ import org.apache.hadoop.hdfs.security.token.block.InvalidBlockTokenException;
import org.apache.hadoop.hdfs.server.common.HdfsServerConstants; import org.apache.hadoop.hdfs.server.common.HdfsServerConstants;
import org.apache.hadoop.hdfs.util.DirectBufferPool; import org.apache.hadoop.hdfs.util.DirectBufferPool;
import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.net.SocketInputStream; import org.apache.hadoop.net.SocketInputWrapper;
import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.util.DataChecksum; import org.apache.hadoop.util.DataChecksum;
@ -450,11 +450,8 @@ public class RemoteBlockReader2 implements BlockReader {
// //
// Get bytes in block, set streams // Get bytes in block, set streams
// //
Preconditions.checkArgument(sock.getChannel() != null, SocketInputWrapper sin = NetUtils.getInputStream(sock);
"Socket %s does not have an associated Channel.", ReadableByteChannel ch = sin.getReadableByteChannel();
sock);
SocketInputStream sin =
(SocketInputStream)NetUtils.getInputStream(sock);
DataInputStream in = new DataInputStream(sin); DataInputStream in = new DataInputStream(sin);
BlockOpResponseProto status = BlockOpResponseProto.parseFrom( BlockOpResponseProto status = BlockOpResponseProto.parseFrom(
@ -477,7 +474,7 @@ public class RemoteBlockReader2 implements BlockReader {
} }
return new RemoteBlockReader2(file, block.getBlockPoolId(), block.getBlockId(), return new RemoteBlockReader2(file, block.getBlockPoolId(), block.getBlockId(),
sin, checksum, verifyChecksum, startOffset, firstChunkOffset, len, sock); ch, checksum, verifyChecksum, startOffset, firstChunkOffset, len, sock);
} }
static void checkSuccess( static void checkSuccess(