diff --git a/hadoop-mapreduce-project/CHANGES.txt b/hadoop-mapreduce-project/CHANGES.txt index b5c53d0dee4..b5d48625258 100644 --- a/hadoop-mapreduce-project/CHANGES.txt +++ b/hadoop-mapreduce-project/CHANGES.txt @@ -318,6 +318,8 @@ Release 2.1.0-beta - UNRELEASED MAPREDUCE-5184. Document compatibility for MapReduce applications in hadoop-2 vis-a-vis hadoop-1. (Zhijie Shen via acmurthy) + MAPREDUCE-5194. Heed interrupts during Fetcher shutdown. (cdouglas) + OPTIMIZATIONS MAPREDUCE-4974. Optimising the LineRecordReader initialize() method diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java index 06e518015f7..4e9059fff25 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/Fetcher.java @@ -84,6 +84,7 @@ class Fetcher extends Thread { private final SecretKey shuffleSecretKey; + protected HttpURLConnection connection; private volatile boolean stopped = false; private static boolean sslShuffle; @@ -93,12 +94,22 @@ class Fetcher extends Thread { ShuffleSchedulerImpl scheduler, MergeManager merger, Reporter reporter, ShuffleClientMetrics metrics, ExceptionReporter exceptionReporter, SecretKey shuffleKey) { + this(job, reduceId, scheduler, merger, reporter, metrics, + exceptionReporter, shuffleKey, ++nextId); + } + + @VisibleForTesting + Fetcher(JobConf job, TaskAttemptID reduceId, + ShuffleSchedulerImpl scheduler, MergeManager merger, + Reporter reporter, ShuffleClientMetrics metrics, + ExceptionReporter exceptionReporter, SecretKey shuffleKey, + int id) { this.reporter = reporter; this.scheduler = scheduler; this.merger = merger; this.metrics = metrics; this.exceptionReporter = exceptionReporter; - this.id = ++nextId; + this.id = id; this.reduce = reduceId.getTaskID().getId(); this.shuffleSecretKey = shuffleKey; ioErrs = reporter.getCounter(SHUFFLE_ERR_GRP_NAME, @@ -166,6 +177,15 @@ class Fetcher extends Thread { } } + @Override + public void interrupt() { + try { + closeConnection(); + } finally { + super.interrupt(); + } + } + public void shutDown() throws InterruptedException { this.stopped = true; interrupt(); @@ -180,7 +200,8 @@ class Fetcher extends Thread { } @VisibleForTesting - protected HttpURLConnection openConnection(URL url) throws IOException { + protected synchronized void openConnection(URL url) + throws IOException { HttpURLConnection conn = (HttpURLConnection) url.openConnection(); if (sslShuffle) { HttpsURLConnection httpsConn = (HttpsURLConnection) conn; @@ -191,9 +212,24 @@ class Fetcher extends Thread { } httpsConn.setHostnameVerifier(sslFactory.getHostnameVerifier()); } - return conn; + connection = conn; } - + + protected synchronized void closeConnection() { + // Note that HttpURLConnection::disconnect() doesn't trash the object. + // connect() attempts to reconnect in a loop, possibly reversing this + if (connection != null) { + connection.disconnect(); + } + } + + private void abortConnect(MapHost host, Set remaining) { + for (TaskAttemptID left : remaining) { + scheduler.putBackKnownMapOutput(host, left); + } + closeConnection(); + } + /** * The crux of the matter... * @@ -220,11 +256,14 @@ class Fetcher extends Thread { Set remaining = new HashSet(maps); // Construct the url and connect - DataInputStream input; - + DataInputStream input = null; try { URL url = getMapOutputURL(host, maps); - HttpURLConnection connection = openConnection(url); + openConnection(url); + if (stopped) { + abortConnect(host, remaining); + return; + } // generate hash of the url String msgToEncode = SecureShuffleUtils.buildMsgFrom(url); @@ -237,6 +276,11 @@ class Fetcher extends Thread { // set the read timeout connection.setReadTimeout(readTimeout); connect(connection, connectionTimeout); + // verify that the thread wasn't stopped during calls to connect + if (stopped) { + abortConnect(host, remaining); + return; + } input = new DataInputStream(connection.getInputStream()); // Validate response code @@ -292,15 +336,19 @@ class Fetcher extends Thread { scheduler.copyFailed(left, host, true, false); } } - - IOUtils.cleanup(LOG, input); - + // Sanity check if (failedTasks == null && !remaining.isEmpty()) { throw new IOException("server didn't return all expected map outputs: " + remaining.size() + " left."); } + input.close(); + input = null; } finally { + if (input != null) { + IOUtils.cleanup(LOG, input); + input = null; + } for (TaskAttemptID left : remaining) { scheduler.putBackKnownMapOutput(host, left); } diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/OnDiskMapOutput.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/OnDiskMapOutput.java index 68713d392f3..59bb04a9dea 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/OnDiskMapOutput.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/OnDiskMapOutput.java @@ -39,11 +39,13 @@ import org.apache.hadoop.mapred.MapOutputFile; import org.apache.hadoop.mapreduce.TaskAttemptID; import org.apache.hadoop.mapreduce.task.reduce.MergeManagerImpl.CompressAwarePath; +import com.google.common.annotations.VisibleForTesting; + @InterfaceAudience.Private @InterfaceStability.Unstable class OnDiskMapOutput extends MapOutput { private static final Log LOG = LogFactory.getLog(OnDiskMapOutput.class); - private final FileSystem localFS; + private final FileSystem fs; private final Path tmpOutputPath; private final Path outputPath; private final MergeManagerImpl merger; @@ -51,20 +53,34 @@ class OnDiskMapOutput extends MapOutput { private long compressedSize; public OnDiskMapOutput(TaskAttemptID mapId, TaskAttemptID reduceId, - MergeManagerImpl merger, long size, + MergeManagerImpl merger, long size, JobConf conf, MapOutputFile mapOutputFile, int fetcher, boolean primaryMapOutput) - throws IOException { - super(mapId, size, primaryMapOutput); - this.merger = merger; - this.localFS = FileSystem.getLocal(conf); - outputPath = - mapOutputFile.getInputFileForWrite(mapId.getTaskID(),size); - tmpOutputPath = outputPath.suffix(String.valueOf(fetcher)); - - disk = localFS.create(tmpOutputPath); + throws IOException { + this(mapId, reduceId, merger, size, conf, mapOutputFile, fetcher, + primaryMapOutput, FileSystem.getLocal(conf), + mapOutputFile.getInputFileForWrite(mapId.getTaskID(), size)); + } + @VisibleForTesting + OnDiskMapOutput(TaskAttemptID mapId, TaskAttemptID reduceId, + MergeManagerImpl merger, long size, + JobConf conf, + MapOutputFile mapOutputFile, + int fetcher, boolean primaryMapOutput, + FileSystem fs, Path outputPath) throws IOException { + super(mapId, size, primaryMapOutput); + this.fs = fs; + this.merger = merger; + this.outputPath = outputPath; + tmpOutputPath = getTempPath(outputPath, fetcher); + disk = fs.create(tmpOutputPath); + } + + @VisibleForTesting + static Path getTempPath(Path outPath, int fetcher) { + return outPath.suffix(String.valueOf(fetcher)); } @Override @@ -114,7 +130,7 @@ class OnDiskMapOutput extends MapOutput { @Override public void commit() throws IOException { - localFS.rename(tmpOutputPath, outputPath); + fs.rename(tmpOutputPath, outputPath); CompressAwarePath compressAwarePath = new CompressAwarePath(outputPath, getSize(), this.compressedSize); merger.closeOnDiskFile(compressAwarePath); @@ -123,7 +139,7 @@ class OnDiskMapOutput extends MapOutput { @Override public void abort() { try { - localFS.delete(tmpOutputPath, false); + fs.delete(tmpOutputPath, false); } catch (IOException ie) { LOG.info("failure to clean up " + tmpOutputPath, ie); } @@ -133,4 +149,5 @@ class OnDiskMapOutput extends MapOutput { public String getDescription() { return "DISK"; } + } diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestFetcher.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestFetcher.java index 8a37399fe60..fdc9d98e539 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestFetcher.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/TestFetcher.java @@ -18,6 +18,23 @@ package org.apache.hadoop.mapreduce.task.reduce; +import java.io.FilterInputStream; + +import java.lang.Void; + +import java.net.HttpURLConnection; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapred.MapOutputFile; +import org.apache.hadoop.mapreduce.TaskID; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.TestName; +import static org.junit.Assert.*; + import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; @@ -26,7 +43,6 @@ import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; -import java.net.HttpURLConnection; import java.net.SocketTimeoutException; import java.net.URL; import java.util.ArrayList; @@ -37,7 +53,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters; -import org.apache.hadoop.mapred.IFileOutputStream; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.mapreduce.TaskAttemptID; @@ -45,69 +60,68 @@ import org.apache.hadoop.mapreduce.security.SecureShuffleUtils; import org.apache.hadoop.mapreduce.security.token.JobTokenSecretManager; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + /** * Test that the Fetcher does what we expect it to. */ public class TestFetcher { private static final Log LOG = LogFactory.getLog(TestFetcher.class); + JobConf job = null; + TaskAttemptID id = null; + ShuffleSchedulerImpl ss = null; + MergeManagerImpl mm = null; + Reporter r = null; + ShuffleClientMetrics metrics = null; + ExceptionReporter except = null; + SecretKey key = null; + HttpURLConnection connection = null; + Counters.Counter allErrs = null; - public static class FakeFetcher extends Fetcher { + final String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; + final MapHost host = new MapHost("localhost", "http://localhost:8080/"); + final TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); + final TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); - private HttpURLConnection connection; + @Rule public TestName name = new TestName(); - public FakeFetcher(JobConf job, TaskAttemptID reduceId, - ShuffleSchedulerImpl scheduler, MergeManagerImpl merger, - Reporter reporter, ShuffleClientMetrics metrics, - ExceptionReporter exceptionReporter, SecretKey jobTokenSecret, - HttpURLConnection connection) { - super(job, reduceId, scheduler, merger, reporter, metrics, exceptionReporter, - jobTokenSecret); - this.connection = connection; - } - - @Override - protected HttpURLConnection openConnection(URL url) throws IOException { - if(connection != null) { - return connection; - } - return super.openConnection(url); - } + @Before + @SuppressWarnings("unchecked") // mocked generics + public void setup() { + LOG.info(">>>> " + name.getMethodName()); + job = new JobConf(); + id = TaskAttemptID.forName("attempt_0_1_r_1_1"); + ss = mock(ShuffleSchedulerImpl.class); + mm = mock(MergeManagerImpl.class); + r = mock(Reporter.class); + metrics = mock(ShuffleClientMetrics.class); + except = mock(ExceptionReporter.class); + key = JobTokenSecretManager.createSecretKey(new byte[]{0,0,0,0}); + connection = mock(HttpURLConnection.class); + + allErrs = mock(Counters.Counter.class); + when(r.getCounter(anyString(), anyString())).thenReturn(allErrs); + + ArrayList maps = new ArrayList(1); + maps.add(map1ID); + maps.add(map2ID); + when(ss.getMapsForHost(host)).thenReturn(maps); + } + + @After + public void teardown() { + LOG.info("<<<< " + name.getMethodName()); } - @SuppressWarnings("unchecked") @Test(timeout=30000) public void testCopyFromHostConnectionTimeout() throws Exception { - LOG.info("testCopyFromHostConnectionTimeout"); - JobConf job = new JobConf(); - TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1"); - ShuffleSchedulerImpl ss = mock(ShuffleSchedulerImpl.class); - MergeManagerImpl mm = mock(MergeManagerImpl.class); - Reporter r = mock(Reporter.class); - ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class); - ExceptionReporter except = mock(ExceptionReporter.class); - SecretKey key = JobTokenSecretManager.createSecretKey(new byte[]{0,0,0,0}); - HttpURLConnection connection = mock(HttpURLConnection.class); when(connection.getInputStream()).thenThrow( new SocketTimeoutException("This is a fake timeout :)")); - Counters.Counter allErrs = mock(Counters.Counter.class); - when(r.getCounter(anyString(), anyString())) - .thenReturn(allErrs); - Fetcher underTest = new FakeFetcher(job, id, ss, mm, r, metrics, except, key, connection); - MapHost host = new MapHost("localhost", "http://localhost:8080/"); - - ArrayList maps = new ArrayList(1); - TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); - maps.add(map1ID); - TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); - maps.add(map2ID); - when(ss.getMapsForHost(host)).thenReturn(maps); - - String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; - underTest.copyFromHost(host); verify(connection) @@ -122,38 +136,11 @@ public class TestFetcher { verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } - @SuppressWarnings("unchecked") @Test public void testCopyFromHostBogusHeader() throws Exception { - LOG.info("testCopyFromHostBogusHeader"); - JobConf job = new JobConf(); - TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1"); - ShuffleSchedulerImpl ss = mock(ShuffleSchedulerImpl.class); - MergeManagerImpl mm = mock(MergeManagerImpl.class); - Reporter r = mock(Reporter.class); - ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class); - ExceptionReporter except = mock(ExceptionReporter.class); - SecretKey key = JobTokenSecretManager.createSecretKey(new byte[]{0,0,0,0}); - HttpURLConnection connection = mock(HttpURLConnection.class); - - Counters.Counter allErrs = mock(Counters.Counter.class); - when(r.getCounter(anyString(), anyString())) - .thenReturn(allErrs); - Fetcher underTest = new FakeFetcher(job, id, ss, mm, r, metrics, except, key, connection); - - MapHost host = new MapHost("localhost", "http://localhost:8080/"); - - ArrayList maps = new ArrayList(1); - TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); - maps.add(map1ID); - TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); - maps.add(map2ID); - when(ss.getMapsForHost(host)).thenReturn(maps); - - String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); @@ -177,38 +164,11 @@ public class TestFetcher { verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } - @SuppressWarnings("unchecked") @Test public void testCopyFromHostWait() throws Exception { - LOG.info("testCopyFromHostWait"); - JobConf job = new JobConf(); - TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1"); - ShuffleSchedulerImpl ss = mock(ShuffleSchedulerImpl.class); - MergeManagerImpl mm = mock(MergeManagerImpl.class); - Reporter r = mock(Reporter.class); - ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class); - ExceptionReporter except = mock(ExceptionReporter.class); - SecretKey key = JobTokenSecretManager.createSecretKey(new byte[]{0,0,0,0}); - HttpURLConnection connection = mock(HttpURLConnection.class); - - Counters.Counter allErrs = mock(Counters.Counter.class); - when(r.getCounter(anyString(), anyString())) - .thenReturn(allErrs); - Fetcher underTest = new FakeFetcher(job, id, ss, mm, r, metrics, except, key, connection); - - MapHost host = new MapHost("localhost", "http://localhost:8080/"); - - ArrayList maps = new ArrayList(1); - TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); - maps.add(map1ID); - TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); - maps.add(map2ID); - when(ss.getMapsForHost(host)).thenReturn(maps); - - String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); @@ -235,112 +195,15 @@ public class TestFetcher { verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); } - @SuppressWarnings("unchecked") - @Test - public void testCopyFromHostExtraBytes() throws Exception { - LOG.info("testCopyFromHostWaitExtraBytes"); - JobConf job = new JobConf(); - TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1"); - ShuffleSchedulerImpl ss = mock(ShuffleSchedulerImpl.class); - MergeManagerImpl mm = mock(MergeManagerImpl.class); - InMemoryMapOutput immo = mock(InMemoryMapOutput.class); - - Reporter r = mock(Reporter.class); - ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class); - ExceptionReporter except = mock(ExceptionReporter.class); - SecretKey key = JobTokenSecretManager.createSecretKey(new byte[]{0,0,0,0}); - HttpURLConnection connection = mock(HttpURLConnection.class); - - Counters.Counter allErrs = mock(Counters.Counter.class); - when(r.getCounter(anyString(), anyString())) - .thenReturn(allErrs); - - Fetcher underTest = new FakeFetcher(job, id, ss, mm, - r, metrics, except, key, connection); - - MapHost host = new MapHost("localhost", "http://localhost:8080/"); - - ArrayList maps = new ArrayList(1); - TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); - maps.add(map1ID); - TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); - maps.add(map2ID); - when(ss.getMapsForHost(host)).thenReturn(maps); - - String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; - String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); - - when(connection.getResponseCode()).thenReturn(200); - when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)) - .thenReturn(replyHash); - ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 14, 10, 1); - - ByteArrayOutputStream bout = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(bout); - IFileOutputStream ios = new IFileOutputStream(dos); - header.write(dos); - ios.write("MAPDATA123".getBytes()); - ios.finish(); - - ShuffleHeader header2 = new ShuffleHeader(map2ID.toString(), 14, 10, 1); - IFileOutputStream ios2 = new IFileOutputStream(dos); - header2.write(dos); - ios2.write("MAPDATA456".getBytes()); - ios2.finish(); - - ByteArrayInputStream in = new ByteArrayInputStream(bout.toByteArray()); - when(connection.getInputStream()).thenReturn(in); - // 8 < 10 therefore there appear to be extra bytes in the IFileInputStream - InMemoryMapOutput mapOut = new InMemoryMapOutput(job, map1ID, mm, 8, null, true ); - InMemoryMapOutput mapOut2 = new InMemoryMapOutput(job, map2ID, mm, 10, null, true ); - - when(mm.reserve(eq(map1ID), anyLong(), anyInt())).thenReturn(mapOut); - when(mm.reserve(eq(map2ID), anyLong(), anyInt())).thenReturn(mapOut2); - - - underTest.copyFromHost(host); - - - verify(allErrs).increment(1); - verify(ss).copyFailed(map1ID, host, true, false); - verify(ss, never()).copyFailed(map2ID, host, true, false); - - verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map1ID)); - verify(ss).putBackKnownMapOutput(any(MapHost.class), eq(map2ID)); - } @SuppressWarnings("unchecked") @Test(timeout=10000) public void testCopyFromHostCompressFailure() throws Exception { - LOG.info("testCopyFromHostCompressFailure"); - JobConf job = new JobConf(); - TaskAttemptID id = TaskAttemptID.forName("attempt_0_1_r_1_1"); - ShuffleSchedulerImpl ss = mock(ShuffleSchedulerImpl.class); - MergeManagerImpl mm = mock(MergeManagerImpl.class); InMemoryMapOutput immo = mock(InMemoryMapOutput.class); - Reporter r = mock(Reporter.class); - ShuffleClientMetrics metrics = mock(ShuffleClientMetrics.class); - ExceptionReporter except = mock(ExceptionReporter.class); - SecretKey key = JobTokenSecretManager.createSecretKey(new byte[]{0,0,0,0}); - HttpURLConnection connection = mock(HttpURLConnection.class); - - Counters.Counter allErrs = mock(Counters.Counter.class); - when(r.getCounter(anyString(), anyString())) - .thenReturn(allErrs); - + Fetcher underTest = new FakeFetcher(job, id, ss, mm, r, metrics, except, key, connection); - - MapHost host = new MapHost("localhost", "http://localhost:8080/"); - - ArrayList maps = new ArrayList(1); - TaskAttemptID map1ID = TaskAttemptID.forName("attempt_0_1_m_1_1"); - maps.add(map1ID); - TaskAttemptID map2ID = TaskAttemptID.forName("attempt_0_1_m_2_1"); - maps.add(map2ID); - when(ss.getMapsForHost(host)).thenReturn(maps); - String encHash = "vFE234EIFCiBgYs2tCXY/SjT8Kg="; String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); when(connection.getResponseCode()).thenReturn(200); @@ -366,4 +229,191 @@ public class TestFetcher { encHash); verify(ss, times(1)).copyFailed(map1ID, host, true, false); } + + @Test(timeout=10000) + public void testInterruptInMemory() throws Exception { + final int FETCHER = 2; + InMemoryMapOutput immo = spy(new InMemoryMapOutput( + job, id, mm, 100, null, true)); + when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())) + .thenReturn(immo); + doNothing().when(mm).waitForResource(); + when(ss.getHost()).thenReturn(host); + + String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); + when(connection.getResponseCode()).thenReturn(200); + when(connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)) + .thenReturn(replyHash); + ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + header.write(new DataOutputStream(bout)); + final StuckInputStream in = + new StuckInputStream(new ByteArrayInputStream(bout.toByteArray())); + when(connection.getInputStream()).thenReturn(in); + doAnswer(new Answer() { + public Void answer(InvocationOnMock ignore) throws IOException { + in.close(); + return null; + } + }).when(connection).disconnect(); + + Fetcher underTest = new FakeFetcher(job, id, ss, mm, + r, metrics, except, key, connection, FETCHER); + underTest.start(); + // wait for read in inputstream + in.waitForFetcher(); + underTest.shutDown(); + underTest.join(); // rely on test timeout to kill if stuck + + assertTrue(in.wasClosedProperly()); + verify(immo).abort(); + } + + @Test(timeout=10000) + public void testInterruptOnDisk() throws Exception { + final int FETCHER = 7; + Path p = new Path("file:///tmp/foo"); + Path pTmp = OnDiskMapOutput.getTempPath(p, FETCHER); + FileSystem mFs = mock(FileSystem.class, RETURNS_DEEP_STUBS); + MapOutputFile mof = mock(MapOutputFile.class); + when(mof.getInputFileForWrite(any(TaskID.class), anyLong())).thenReturn(p); + OnDiskMapOutput odmo = spy(new OnDiskMapOutput(map1ID, + id, mm, 100L, job, mof, FETCHER, true, mFs, p)); + when(mm.reserve(any(TaskAttemptID.class), anyLong(), anyInt())) + .thenReturn(odmo); + doNothing().when(mm).waitForResource(); + when(ss.getHost()).thenReturn(host); + + String replyHash = SecureShuffleUtils.generateHash(encHash.getBytes(), key); + when(connection.getResponseCode()).thenReturn(200); + when(connection.getHeaderField( + SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH)).thenReturn(replyHash); + ShuffleHeader header = new ShuffleHeader(map1ID.toString(), 10, 10, 1); + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + header.write(new DataOutputStream(bout)); + final StuckInputStream in = + new StuckInputStream(new ByteArrayInputStream(bout.toByteArray())); + when(connection.getInputStream()).thenReturn(in); + doAnswer(new Answer() { + public Void answer(InvocationOnMock ignore) throws IOException { + in.close(); + return null; + } + }).when(connection).disconnect(); + + Fetcher underTest = new FakeFetcher(job, id, ss, mm, + r, metrics, except, key, connection, FETCHER); + underTest.start(); + // wait for read in inputstream + in.waitForFetcher(); + underTest.shutDown(); + underTest.join(); // rely on test timeout to kill if stuck + + assertTrue(in.wasClosedProperly()); + verify(mFs).create(eq(pTmp)); + verify(mFs).delete(eq(pTmp), eq(false)); + verify(odmo).abort(); + } + + public static class FakeFetcher extends Fetcher { + + public FakeFetcher(JobConf job, TaskAttemptID reduceId, + ShuffleSchedulerImpl scheduler, MergeManagerImpl merger, + Reporter reporter, ShuffleClientMetrics metrics, + ExceptionReporter exceptionReporter, SecretKey jobTokenSecret, + HttpURLConnection connection) { + super(job, reduceId, scheduler, merger, reporter, metrics, + exceptionReporter, jobTokenSecret); + this.connection = connection; + } + + public FakeFetcher(JobConf job, TaskAttemptID reduceId, + ShuffleSchedulerImpl scheduler, MergeManagerImpl merger, + Reporter reporter, ShuffleClientMetrics metrics, + ExceptionReporter exceptionReporter, SecretKey jobTokenSecret, + HttpURLConnection connection, int id) { + super(job, reduceId, scheduler, merger, reporter, metrics, + exceptionReporter, jobTokenSecret, id); + this.connection = connection; + } + + @Override + protected void openConnection(URL url) throws IOException { + if (null == connection) { + super.openConnection(url); + } + // already 'opened' the mocked connection + return; + } + } + + static class StuckInputStream extends FilterInputStream { + + boolean stuck = false; + volatile boolean closed = false; + + StuckInputStream(InputStream inner) { + super(inner); + } + + int freeze() throws IOException { + synchronized (this) { + stuck = true; + notify(); + } + // connection doesn't throw InterruptedException, but may return some + // bytes geq 0 or throw an exception + while (!Thread.currentThread().isInterrupted() || closed) { + // spin + if (closed) { + throw new IOException("underlying stream closed, triggered an error"); + } + } + return 0; + } + + @Override + public int read() throws IOException { + int ret = super.read(); + if (ret != -1) { + return ret; + } + return freeze(); + } + + @Override + public int read(byte[] b) throws IOException { + int ret = super.read(b); + if (ret != -1) { + return ret; + } + return freeze(); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int ret = super.read(b, off, len); + if (ret != -1) { + return ret; + } + return freeze(); + } + + @Override + public void close() throws IOException { + closed = true; + } + + public synchronized void waitForFetcher() throws InterruptedException { + while (!stuck) { + wait(); + } + } + + public boolean wasClosedProperly() { + return closed; + } + + } + }