diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java index 55f92d66638..7bc1ea20bc8 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/Shell.java @@ -27,7 +27,9 @@ import java.io.InterruptedIOException; import java.nio.charset.Charset; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import java.util.Timer; import java.util.TimerTask; import java.util.WeakHashMap; @@ -50,8 +52,8 @@ import org.slf4j.LoggerFactory; @InterfaceAudience.Public @InterfaceStability.Evolving public abstract class Shell { - private static final Map CHILD_PROCESSES = - Collections.synchronizedMap(new WeakHashMap()); + private static final Map CHILD_SHELLS = + Collections.synchronizedMap(new WeakHashMap()); public static final Logger LOG = LoggerFactory.getLogger(Shell.class); /** @@ -820,6 +822,7 @@ public abstract class Shell { private File dir; private Process process; // sub process used to execute the command private int exitCode; + private Thread waitingThread; /** Flag to indicate whether or not the script has finished executing. */ private final AtomicBoolean completed = new AtomicBoolean(false); @@ -924,7 +927,9 @@ public abstract class Shell { } else { process = builder.start(); } - CHILD_PROCESSES.put(process, null); + + waitingThread = Thread.currentThread(); + CHILD_SHELLS.put(this, null); if (timeOutInterval > 0) { timeOutTimer = new Timer("Shell command timeout"); @@ -1021,7 +1026,8 @@ public abstract class Shell { LOG.warn("Error while closing the error stream", ioe); } process.destroy(); - CHILD_PROCESSES.remove(process); + waitingThread = null; + CHILD_SHELLS.remove(this); lastTime = Time.monotonicNow(); } } @@ -1069,6 +1075,15 @@ public abstract class Shell { return exitCode; } + /** get the thread that is waiting on this instance of Shell. + * @return the thread that ran runCommand() that spawned this shell + * or null if no thread is waiting for this shell to complete + */ + public Thread getWaitingThread() { + return waitingThread; + } + + /** * This is an IOException with exit code added. */ @@ -1322,20 +1337,27 @@ public abstract class Shell { } /** - * Static method to destroy all running Shell processes - * Iterates through a list of all currently running Shell - * processes and destroys them one by one. This method is thread safe and - * is intended to be used in a shutdown hook. + * Static method to destroy all running Shell processes. + * Iterates through a map of all currently running Shell + * processes and destroys them one by one. This method is thread safe */ - public static void destroyAllProcesses() { - synchronized (CHILD_PROCESSES) { - for (Process key : CHILD_PROCESSES.keySet()) { - Process process = key; - if (key != null) { - process.destroy(); + public static void destroyAllShellProcesses() { + synchronized (CHILD_SHELLS) { + for (Shell shell : CHILD_SHELLS.keySet()) { + if (shell.getProcess() != null) { + shell.getProcess().destroy(); } } - CHILD_PROCESSES.clear(); + CHILD_SHELLS.clear(); + } + } + + /** + * Static method to return a Set of all Shell objects. + */ + public static Set getAllShells() { + synchronized (CHILD_SHELLS) { + return new HashSet<>(CHILD_SHELLS.keySet()); } } } diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java index 2707573ab4f..7e53883440c 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestShell.java @@ -479,7 +479,7 @@ public class TestShell extends Assert { } @Test(timeout=120000) - public void testShellKillAllProcesses() throws Throwable { + public void testDestroyAllShellProcesses() throws Throwable { Assume.assumeFalse(WINDOWS); StringBuffer sleepCommand = new StringBuffer(); sleepCommand.append("sleep 200"); @@ -524,7 +524,7 @@ public class TestShell extends Assert { } }, 10, 10000); - Shell.destroyAllProcesses(); + Shell.destroyAllShellProcesses(); shexc1.getProcess().waitFor(); shexc2.getProcess().waitFor(); } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java index 04be6318c53..613c0a95628 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/ContainerLocalizer.java @@ -24,10 +24,13 @@ import java.net.InetSocketAddress; import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletionService; @@ -53,6 +56,7 @@ import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.util.DiskValidator; import org.apache.hadoop.util.DiskValidatorFactory; +import org.apache.hadoop.util.Shell; import org.apache.hadoop.util.concurrent.HadoopExecutors; import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler; import org.apache.hadoop.yarn.api.records.LocalResource; @@ -75,6 +79,8 @@ import org.apache.hadoop.yarn.util.FSDownload; import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import static org.apache.hadoop.util.Shell.getAllShells; + public class ContainerLocalizer { static final Log LOG = LogFactory.getLog(ContainerLocalizer.class); @@ -101,6 +107,9 @@ public class ContainerLocalizer { private final String appCacheDirContextName; private final DiskValidator diskValidator; + private Set localizingThreads = + Collections.synchronizedSet(new HashSet()); + public ContainerLocalizer(FileContext lfs, String user, String appId, String localizerId, List localDirs, RecordFactory recordFactory) throws IOException { @@ -178,13 +187,14 @@ public class ContainerLocalizer { exec = createDownloadThreadPool(); CompletionService ecs = createCompletionService(exec); localizeFiles(nodeManager, ecs, ugi); - return; } catch (Throwable e) { throw new IOException(e); } finally { try { if (exec != null) { - exec.shutdownNow(); + exec.shutdown(); + destroyShellProcesses(getAllShells()); + exec.awaitTermination(10, TimeUnit.SECONDS); } LocalDirAllocator.removeContext(appCacheDirContextName); } finally { @@ -202,10 +212,34 @@ public class ContainerLocalizer { return new ExecutorCompletionService(exec); } + class FSDownloadWrapper extends FSDownload { + + FSDownloadWrapper(FileContext files, UserGroupInformation ugi, + Configuration conf, Path destDirPath, LocalResource resource) { + super(files, ugi, conf, destDirPath, resource); + } + + @Override + public Path call() throws Exception { + Thread currentThread = Thread.currentThread(); + localizingThreads.add(currentThread); + try { + return doDownloadCall(); + } finally { + localizingThreads.remove(currentThread); + } + } + + Path doDownloadCall() throws Exception { + return super.call(); + } + + } + Callable download(Path path, LocalResource rsrc, UserGroupInformation ugi) throws IOException { diskValidator.checkStatus(new File(path.toUri().getRawPath())); - return new FSDownload(lfs, ugi, conf, path, rsrc); + return new FSDownloadWrapper(lfs, ugi, conf, path, rsrc); } static long getEstimatedSize(LocalResource rsrc) { @@ -363,6 +397,7 @@ public class ContainerLocalizer { public static void main(String[] argv) throws Throwable { Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler()); + int nRet = 0; // usage: $0 user appId locId host port app_log_dir user_dir [user_dir]* // let $x = $x/usercache for $local.dir // MKDIR $x/$user/appcache/$appid @@ -399,7 +434,9 @@ public class ContainerLocalizer { // space in both DefaultCE and LCE cases e.printStackTrace(System.out); LOG.error("Exception in main:", e); - System.exit(-1); + nRet = -1; + } finally { + System.exit(nRet); } } @@ -436,4 +473,12 @@ public class ContainerLocalizer { lfs.setPermission(dirPath, perms); } } + + private void destroyShellProcesses(Set shells) { + for (Shell shell : shells) { + if(localizingThreads.contains(shell.getWaitingThread())) { + shell.getProcess().destroy(); + } + } + } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java index fac708655f5..4681b54ee93 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/localizer/TestContainerLocalizer.java @@ -17,6 +17,8 @@ */ package org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer; +import static junit.framework.TestCase.assertFalse; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; @@ -25,6 +27,7 @@ import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isA; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; @@ -46,6 +49,7 @@ import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import com.google.common.base.Supplier; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -60,6 +64,9 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.test.GenericTestUtils; +import org.apache.hadoop.util.Shell; +import org.apache.hadoop.util.Shell.ShellCommandExecutor; import org.apache.hadoop.yarn.api.records.LocalResource; import org.apache.hadoop.yarn.api.records.LocalResourceType; import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; @@ -76,6 +83,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils; import org.junit.Assert; import org.junit.Test; import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -92,18 +100,18 @@ public class TestContainerLocalizer { static final InetSocketAddress nmAddr = new InetSocketAddress("foobar", 8040); - private AbstractFileSystem spylfs; - private Random random; - private List localDirs; - private Path tokenPath; - private LocalizationProtocol nmProxy; @Test public void testMain() throws Exception { - FileContext fs = FileContext.getLocalFSFileContext(); - spylfs = spy(fs.getDefaultFileSystem()); + ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper(); ContainerLocalizer localizer = - setupContainerLocalizerForTest(); + wrapper.setupContainerLocalizerForTest(); + Random random = wrapper.random; + List localDirs = wrapper.localDirs; + Path tokenPath = wrapper.tokenPath; + LocalizationProtocol nmProxy = wrapper.nmProxy; + AbstractFileSystem spylfs = wrapper.spylfs; + mockOutDownloads(localizer); // verify created cache List privCacheList = new ArrayList(); @@ -131,7 +139,7 @@ public class TestContainerLocalizer { ResourceLocalizationSpec rsrcD = getMockRsrc(random, LocalResourceVisibility.PRIVATE, privCacheList.get(0)); - + when(nmProxy.heartbeat(isA(LocalizerStatus.class))) .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, Collections.singletonList(rsrcA))) @@ -202,10 +210,10 @@ public class TestContainerLocalizer { @Test(timeout = 15000) public void testMainFailure() throws Exception { - - FileContext fs = FileContext.getLocalFSFileContext(); - spylfs = spy(fs.getDefaultFileSystem()); - ContainerLocalizer localizer = setupContainerLocalizerForTest(); + ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper(); + ContainerLocalizer localizer = wrapper.setupContainerLocalizerForTest(); + LocalizationProtocol nmProxy = wrapper.nmProxy; + mockOutDownloads(localizer); // Assume the NM heartbeat fails say because of absent tokens. when(nmProxy.heartbeat(isA(LocalizerStatus.class))).thenThrow( @@ -223,9 +231,11 @@ public class TestContainerLocalizer { @Test @SuppressWarnings("unchecked") public void testLocalizerTokenIsGettingRemoved() throws Exception { - FileContext fs = FileContext.getLocalFSFileContext(); - spylfs = spy(fs.getDefaultFileSystem()); - ContainerLocalizer localizer = setupContainerLocalizerForTest(); + ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper(); + ContainerLocalizer localizer = wrapper.setupContainerLocalizerForTest(); + Path tokenPath = wrapper.tokenPath; + AbstractFileSystem spylfs = wrapper.spylfs; + mockOutDownloads(localizer); doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), any(CompletionService.class), any(UserGroupInformation.class)); localizer.runLocalization(nmAddr); @@ -237,10 +247,10 @@ public class TestContainerLocalizer { public void testContainerLocalizerClosesFilesystems() throws Exception { // verify filesystems are closed when localizer doesn't fail - FileContext fs = FileContext.getLocalFSFileContext(); - spylfs = spy(fs.getDefaultFileSystem()); + ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper(); - ContainerLocalizer localizer = setupContainerLocalizerForTest(); + ContainerLocalizer localizer = wrapper.setupContainerLocalizerForTest(); + mockOutDownloads(localizer); doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), any(CompletionService.class), any(UserGroupInformation.class)); verify(localizer, never()).closeFileSystems( @@ -249,10 +259,8 @@ public class TestContainerLocalizer { localizer.runLocalization(nmAddr); verify(localizer).closeFileSystems(any(UserGroupInformation.class)); - spylfs = spy(fs.getDefaultFileSystem()); - // verify filesystems are closed when localizer fails - localizer = setupContainerLocalizerForTest(); + localizer = wrapper.setupContainerLocalizerForTest(); doThrow(new YarnRuntimeException("Forced Failure")).when(localizer).localizeFiles( any(LocalizationProtocol.class), any(CompletionService.class), any(UserGroupInformation.class)); @@ -266,41 +274,109 @@ public class TestContainerLocalizer { } } - @SuppressWarnings("unchecked") // mocked generics - private ContainerLocalizer setupContainerLocalizerForTest() - throws Exception { - // don't actually create dirs - doNothing().when(spylfs).mkdir( - isA(Path.class), isA(FsPermission.class), anyBoolean()); + @Test(timeout = 30000) + public void testMultipleLocalizers() throws Exception { + FakeContainerLocalizerWrapper testA = new FakeContainerLocalizerWrapper(); + FakeContainerLocalizerWrapper testB = new FakeContainerLocalizerWrapper(); - Configuration conf = new Configuration(); - FileContext lfs = FileContext.getFileContext(spylfs, conf); - localDirs = new ArrayList(); - for (int i = 0; i < 4; ++i) { - localDirs.add(lfs.makeQualified(new Path(basedir, i + ""))); + final FakeContainerLocalizer localizerA = testA.init(); + final FakeContainerLocalizer localizerB = testB.init(); + + // run localization + Thread threadA = new Thread() { + @Override + public void run() { + try { + localizerA.runLocalization(nmAddr); + } catch (Exception e) { + LOG.warn(e); + } + } + }; + Thread threadB = new Thread() { + @Override + public void run() { + try { + localizerB.runLocalization(nmAddr); + } catch (Exception e) { + LOG.warn(e); + } + } + }; + ShellCommandExecutor shexcA = null; + ShellCommandExecutor shexcB = null; + try { + threadA.start(); + threadB.start(); + + GenericTestUtils.waitFor(new Supplier() { + @Override + public Boolean get() { + FakeContainerLocalizer.FakeLongDownload downloader = + localizerA.getDownloader(); + return downloader != null && downloader.getShexc() != null && + downloader.getShexc().getProcess() != null; + } + }, 10, 30000); + + GenericTestUtils.waitFor(new Supplier() { + @Override + public Boolean get() { + FakeContainerLocalizer.FakeLongDownload downloader = + localizerB.getDownloader(); + return downloader != null && downloader.getShexc() != null && + downloader.getShexc().getProcess() != null; + } + }, 10, 30000); + + shexcA = localizerA.getDownloader().getShexc(); + shexcB = localizerB.getDownloader().getShexc(); + + assertTrue("Localizer A process not running, but should be", + isAlive(shexcA.getProcess())); + assertTrue("Localizer B process not running, but should be", + isAlive(shexcB.getProcess())); + + // Stop heartbeat from giving anymore resources to download + testA.heartbeatResponse++; + testB.heartbeatResponse++; + + // Send DIE to localizerA. This should kill its subprocesses + testA.heartbeatResponse++; + + threadA.join(); + shexcA.getProcess().waitFor(); + + assertFalse("Localizer A process is still running, but shouldn't be", + isAlive(shexcA.getProcess())); + assertTrue("Localizer B process not running, but should be", + isAlive(shexcB.getProcess())); + + } finally { + // Make sure everything gets cleaned up + // Process A should already be dead + shexcA.getProcess().destroy(); + shexcB.getProcess().destroy(); + shexcA.getProcess().waitFor(); + shexcB.getProcess().waitFor(); + + threadA.join(); + // Send DIE to localizer B + testB.heartbeatResponse++; + threadB.join(); } - RecordFactory mockRF = getMockLocalizerRecordFactory(); - ContainerLocalizer concreteLoc = new ContainerLocalizer(lfs, appUser, - appId, containerId, localDirs, mockRF); - ContainerLocalizer localizer = spy(concreteLoc); + } - // return credential stream instead of opening local file - random = new Random(); - long seed = random.nextLong(); - System.out.println("SEED: " + seed); - random.setSeed(seed); - DataInputBuffer appTokens = createFakeCredentials(random, 10); - tokenPath = - lfs.makeQualified(new Path( - String.format(ContainerLocalizer.TOKEN_FILE_NAME_FMT, - containerId))); - doReturn(new FSDataInputStream(new FakeFSDataInputStream(appTokens)) - ).when(spylfs).open(tokenPath); - nmProxy = mock(LocalizationProtocol.class); - doReturn(nmProxy).when(localizer).getProxy(nmAddr); - doNothing().when(localizer).sleep(anyInt()); - + private boolean isAlive(Process p) { + try { + p.exitValue(); + return false; + } catch(IllegalThreadStateException e) { + return true; + } + } + private void mockOutDownloads(ContainerLocalizer localizer) { // return result instantly for deterministic test ExecutorService syncExec = mock(ExecutorService.class); CompletionService cs = mock(CompletionService.class); @@ -318,8 +394,6 @@ public class TestContainerLocalizer { }); doReturn(syncExec).when(localizer).createDownloadThreadPool(); doReturn(cs).when(localizer).createCompletionService(syncExec); - - return localizer; } static class HBMatches extends ArgumentMatcher { @@ -363,6 +437,141 @@ public class TestContainerLocalizer { } } + class FakeContainerLocalizer extends ContainerLocalizer { + private FakeLongDownload downloader; + + FakeContainerLocalizer(FileContext lfs, String user, String appId, + String localizerId, List localDirs, + RecordFactory recordFactory) throws IOException { + super(lfs, user, appId, localizerId, localDirs, recordFactory); + } + + FakeLongDownload getDownloader() { + return downloader; + } + + @Override + Callable download(Path path, LocalResource rsrc, + UserGroupInformation ugi) throws IOException { + downloader = new FakeLongDownload(Mockito.mock(FileContext.class), ugi, + new Configuration(), path, rsrc); + return downloader; + } + + class FakeLongDownload extends ContainerLocalizer.FSDownloadWrapper { + private final Path localPath; + private Shell.ShellCommandExecutor shexc; + FakeLongDownload(FileContext files, UserGroupInformation ugi, + Configuration conf, Path destDirPath, LocalResource resource) { + super(files, ugi, conf, destDirPath, resource); + this.localPath = new Path("file:///localcache"); + } + + Shell.ShellCommandExecutor getShexc() { + return shexc; + } + + @Override + public Path doDownloadCall() throws IOException { + String sleepCommand = "sleep 30"; + String[] shellCmd = {"bash", "-c", sleepCommand}; + shexc = new Shell.ShellCommandExecutor(shellCmd); + shexc.execute(); + + return localPath; + } + } + } + + class ContainerLocalizerWrapper { + AbstractFileSystem spylfs; + Random random; + List localDirs; + Path tokenPath; + LocalizationProtocol nmProxy; + + @SuppressWarnings("unchecked") // mocked generics + FakeContainerLocalizer setupContainerLocalizerForTest() + throws Exception { + + FileContext fs = FileContext.getLocalFSFileContext(); + spylfs = spy(fs.getDefaultFileSystem()); + // don't actually create dirs + doNothing().when(spylfs).mkdir( + isA(Path.class), isA(FsPermission.class), anyBoolean()); + + Configuration conf = new Configuration(); + FileContext lfs = FileContext.getFileContext(spylfs, conf); + localDirs = new ArrayList(); + for (int i = 0; i < 4; ++i) { + localDirs.add(lfs.makeQualified(new Path(basedir, i + ""))); + } + RecordFactory mockRF = getMockLocalizerRecordFactory(); + FakeContainerLocalizer concreteLoc = new FakeContainerLocalizer(lfs, + appUser, appId, containerId, localDirs, mockRF); + FakeContainerLocalizer localizer = spy(concreteLoc); + + // return credential stream instead of opening local file + random = new Random(); + long seed = random.nextLong(); + System.out.println("SEED: " + seed); + random.setSeed(seed); + DataInputBuffer appTokens = createFakeCredentials(random, 10); + tokenPath = + lfs.makeQualified(new Path( + String.format(ContainerLocalizer.TOKEN_FILE_NAME_FMT, + containerId))); + doReturn(new FSDataInputStream(new FakeFSDataInputStream(appTokens)) + ).when(spylfs).open(tokenPath); + nmProxy = mock(LocalizationProtocol.class); + doReturn(nmProxy).when(localizer).getProxy(nmAddr); + doNothing().when(localizer).sleep(anyInt()); + + return localizer; + } + + } + + class FakeContainerLocalizerWrapper extends ContainerLocalizerWrapper{ + private int heartbeatResponse = 0; + public FakeContainerLocalizer init() throws Exception { + FileContext fs = FileContext.getLocalFSFileContext(); + FakeContainerLocalizer localizer = setupContainerLocalizerForTest(); + + // verify created cache + List privCacheList = new ArrayList(); + for (Path p : localDirs) { + Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE), + appUser); + Path privcache = new Path(base, ContainerLocalizer.FILECACHE); + privCacheList.add(privcache); + } + + final ResourceLocalizationSpec rsrc = getMockRsrc(random, + LocalResourceVisibility.PRIVATE, privCacheList.get(0)); + + // mock heartbeat responses from NM + doAnswer(new Answer() { + @Override + public MockLocalizerHeartbeatResponse answer( + InvocationOnMock invocationOnMock) throws Throwable { + if(heartbeatResponse == 0) { + return new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.singletonList(rsrc)); + } else if (heartbeatResponse < 2) { + return new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, + Collections.emptyList()); + } else { + return new MockLocalizerHeartbeatResponse(LocalizerAction.DIE, + null); + } + } + }).when(nmProxy).heartbeat(isA(LocalizerStatus.class)); + + return localizer; + } + } + static RecordFactory getMockLocalizerRecordFactory() { RecordFactory mockRF = mock(RecordFactory.class); when(mockRF.newRecordInstance(same(LocalResourceStatus.class)))