YARN-5641. Localizer leaves behind tarballs after container is complete. Contributed by Eric Badger

This commit is contained in:
Jason Lowe 2017-01-27 15:25:57 +00:00
parent 660f4d8631
commit 4703f5d20b
4 changed files with 352 additions and 76 deletions

View File

@ -27,7 +27,9 @@ import java.io.InterruptedIOException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.Timer; import java.util.Timer;
import java.util.TimerTask; import java.util.TimerTask;
import java.util.WeakHashMap; import java.util.WeakHashMap;
@ -50,8 +52,8 @@ import org.slf4j.LoggerFactory;
@InterfaceAudience.Public @InterfaceAudience.Public
@InterfaceStability.Evolving @InterfaceStability.Evolving
public abstract class Shell { public abstract class Shell {
private static final Map <Process, Object> CHILD_PROCESSES = private static final Map<Shell, Object> CHILD_SHELLS =
Collections.synchronizedMap(new WeakHashMap<Process, Object>()); Collections.synchronizedMap(new WeakHashMap<Shell, Object>());
public static final Logger LOG = LoggerFactory.getLogger(Shell.class); public static final Logger LOG = LoggerFactory.getLogger(Shell.class);
/** /**
@ -820,6 +822,7 @@ public abstract class Shell {
private File dir; private File dir;
private Process process; // sub process used to execute the command private Process process; // sub process used to execute the command
private int exitCode; private int exitCode;
private Thread waitingThread;
/** Flag to indicate whether or not the script has finished executing. */ /** Flag to indicate whether or not the script has finished executing. */
private final AtomicBoolean completed = new AtomicBoolean(false); private final AtomicBoolean completed = new AtomicBoolean(false);
@ -924,7 +927,9 @@ public abstract class Shell {
} else { } else {
process = builder.start(); process = builder.start();
} }
CHILD_PROCESSES.put(process, null);
waitingThread = Thread.currentThread();
CHILD_SHELLS.put(this, null);
if (timeOutInterval > 0) { if (timeOutInterval > 0) {
timeOutTimer = new Timer("Shell command timeout"); timeOutTimer = new Timer("Shell command timeout");
@ -1021,7 +1026,8 @@ public abstract class Shell {
LOG.warn("Error while closing the error stream", ioe); LOG.warn("Error while closing the error stream", ioe);
} }
process.destroy(); process.destroy();
CHILD_PROCESSES.remove(process); waitingThread = null;
CHILD_SHELLS.remove(this);
lastTime = Time.monotonicNow(); lastTime = Time.monotonicNow();
} }
} }
@ -1069,6 +1075,15 @@ public abstract class Shell {
return exitCode; return exitCode;
} }
/**
 * Returns the thread that is waiting on this instance of
 * <code>Shell</code>.
 *
 * NOTE(review): the backing field is a plain (non-volatile) field in the
 * visible declaration, so a caller on another thread may observe a stale
 * value - confirm whether that matters for the callers.
 *
 * @return the thread that ran runCommand() and spawned this shell, or
 *         null if no thread is waiting for this shell to complete
 */
public Thread getWaitingThread() {
  final Thread waiter = this.waitingThread;
  return waiter;
}
/** /**
* This is an IOException with exit code added. * This is an IOException with exit code added.
*/ */
@ -1322,20 +1337,27 @@ public abstract class Shell {
} }
/** /**
* Static method to destroy all running <code>Shell</code> processes * Static method to destroy all running <code>Shell</code> processes.
* Iterates through a list of all currently running <code>Shell</code> * Iterates through a map of all currently running <code>Shell</code>
* processes and destroys them one by one. This method is thread safe and * processes and destroys them one by one. This method is thread safe
* is intended to be used in a shutdown hook.
*/ */
public static void destroyAllProcesses() { public static void destroyAllShellProcesses() {
synchronized (CHILD_PROCESSES) { synchronized (CHILD_SHELLS) {
for (Process key : CHILD_PROCESSES.keySet()) { for (Shell shell : CHILD_SHELLS.keySet()) {
Process process = key; if (shell.getProcess() != null) {
if (key != null) { shell.getProcess().destroy();
process.destroy();
} }
} }
CHILD_PROCESSES.clear(); CHILD_SHELLS.clear();
}
}
/**
* Static method to return a Set of all <code>Shell</code> objects.
*/
public static Set<Shell> getAllShells() {
synchronized (CHILD_SHELLS) {
return new HashSet<>(CHILD_SHELLS.keySet());
} }
} }
} }

View File

@ -479,7 +479,7 @@ public class TestShell extends Assert {
} }
@Test(timeout=120000) @Test(timeout=120000)
public void testShellKillAllProcesses() throws Throwable { public void testDestroyAllShellProcesses() throws Throwable {
Assume.assumeFalse(WINDOWS); Assume.assumeFalse(WINDOWS);
StringBuffer sleepCommand = new StringBuffer(); StringBuffer sleepCommand = new StringBuffer();
sleepCommand.append("sleep 200"); sleepCommand.append("sleep 200");
@ -524,7 +524,7 @@ public class TestShell extends Assert {
} }
}, 10, 10000); }, 10, 10000);
Shell.destroyAllProcesses(); Shell.destroyAllShellProcesses();
shexc1.getProcess().waitFor(); shexc1.getProcess().waitFor();
shexc2.getProcess().waitFor(); shexc2.getProcess().waitFor();
} }

View File

@ -24,10 +24,13 @@ import java.net.InetSocketAddress;
import java.security.PrivilegedAction; import java.security.PrivilegedAction;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionService; import java.util.concurrent.CompletionService;
@ -53,6 +56,7 @@ import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.DiskValidator; import org.apache.hadoop.util.DiskValidator;
import org.apache.hadoop.util.DiskValidatorFactory; import org.apache.hadoop.util.DiskValidatorFactory;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.util.concurrent.HadoopExecutors; import org.apache.hadoop.util.concurrent.HadoopExecutors;
import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler; import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
import org.apache.hadoop.yarn.api.records.LocalResource; import org.apache.hadoop.yarn.api.records.LocalResource;
@ -75,6 +79,8 @@ import org.apache.hadoop.yarn.util.FSDownload;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.util.concurrent.ThreadFactoryBuilder;
import static org.apache.hadoop.util.Shell.getAllShells;
public class ContainerLocalizer { public class ContainerLocalizer {
static final Log LOG = LogFactory.getLog(ContainerLocalizer.class); static final Log LOG = LogFactory.getLog(ContainerLocalizer.class);
@ -101,6 +107,9 @@ public class ContainerLocalizer {
private final String appCacheDirContextName; private final String appCacheDirContextName;
private final DiskValidator diskValidator; private final DiskValidator diskValidator;
private Set<Thread> localizingThreads =
Collections.synchronizedSet(new HashSet<Thread>());
public ContainerLocalizer(FileContext lfs, String user, String appId, public ContainerLocalizer(FileContext lfs, String user, String appId,
String localizerId, List<Path> localDirs, String localizerId, List<Path> localDirs,
RecordFactory recordFactory) throws IOException { RecordFactory recordFactory) throws IOException {
@ -178,13 +187,14 @@ public class ContainerLocalizer {
exec = createDownloadThreadPool(); exec = createDownloadThreadPool();
CompletionService<Path> ecs = createCompletionService(exec); CompletionService<Path> ecs = createCompletionService(exec);
localizeFiles(nodeManager, ecs, ugi); localizeFiles(nodeManager, ecs, ugi);
return;
} catch (Throwable e) { } catch (Throwable e) {
throw new IOException(e); throw new IOException(e);
} finally { } finally {
try { try {
if (exec != null) { if (exec != null) {
exec.shutdownNow(); exec.shutdown();
destroyShellProcesses(getAllShells());
exec.awaitTermination(10, TimeUnit.SECONDS);
} }
LocalDirAllocator.removeContext(appCacheDirContextName); LocalDirAllocator.removeContext(appCacheDirContextName);
} finally { } finally {
@ -202,10 +212,34 @@ public class ContainerLocalizer {
return new ExecutorCompletionService<Path>(exec); return new ExecutorCompletionService<Path>(exec);
} }
/**
 * An {@link FSDownload} that records the thread executing the download in
 * the enclosing localizer's <code>localizingThreads</code> set for the
 * duration of the call, so the localizer can tell which threads are
 * currently downloading on its behalf.
 */
class FSDownloadWrapper extends FSDownload {

  FSDownloadWrapper(FileContext files, UserGroupInformation ugi,
      Configuration conf, Path destDirPath, LocalResource resource) {
    super(files, ugi, conf, destDirPath, resource);
  }

  /**
   * Registers the calling thread, performs the download, and always
   * unregisters the thread again, whether the download succeeds or throws.
   *
   * @return the path produced by {@link #doDownloadCall()}
   * @throws Exception whatever {@link #doDownloadCall()} throws
   */
  @Override
  public Path call() throws Exception {
    final Thread worker = Thread.currentThread();
    localizingThreads.add(worker);
    try {
      final Path localized = doDownloadCall();
      return localized;
    } finally {
      localizingThreads.remove(worker);
    }
  }

  /**
   * Performs the actual download by delegating to the superclass. Kept as
   * a separate method so subclasses can substitute the download step
   * without losing the thread bookkeeping done in {@link #call()}.
   */
  Path doDownloadCall() throws Exception {
    return super.call();
  }
}
Callable<Path> download(Path path, LocalResource rsrc, Callable<Path> download(Path path, LocalResource rsrc,
UserGroupInformation ugi) throws IOException { UserGroupInformation ugi) throws IOException {
diskValidator.checkStatus(new File(path.toUri().getRawPath())); diskValidator.checkStatus(new File(path.toUri().getRawPath()));
return new FSDownload(lfs, ugi, conf, path, rsrc); return new FSDownloadWrapper(lfs, ugi, conf, path, rsrc);
} }
static long getEstimatedSize(LocalResource rsrc) { static long getEstimatedSize(LocalResource rsrc) {
@ -363,6 +397,7 @@ public class ContainerLocalizer {
public static void main(String[] argv) throws Throwable { public static void main(String[] argv) throws Throwable {
Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler()); Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler());
int nRet = 0;
// usage: $0 user appId locId host port app_log_dir user_dir [user_dir]* // usage: $0 user appId locId host port app_log_dir user_dir [user_dir]*
// let $x = $x/usercache for $local.dir // let $x = $x/usercache for $local.dir
// MKDIR $x/$user/appcache/$appid // MKDIR $x/$user/appcache/$appid
@ -399,7 +434,9 @@ public class ContainerLocalizer {
// space in both DefaultCE and LCE cases // space in both DefaultCE and LCE cases
e.printStackTrace(System.out); e.printStackTrace(System.out);
LOG.error("Exception in main:", e); LOG.error("Exception in main:", e);
System.exit(-1); nRet = -1;
} finally {
System.exit(nRet);
} }
} }
@ -436,4 +473,12 @@ public class ContainerLocalizer {
lfs.setPermission(dirPath, perms); lfs.setPermission(dirPath, perms);
} }
} }
/**
 * Destroys the sub-process of every given shell whose waiting thread is
 * one of this localizer's download threads. Shells started by any other
 * thread are left untouched.
 *
 * @param shells the candidate shells to inspect
 */
private void destroyShellProcesses(Set<Shell> shells) {
  for (Shell shell : shells) {
    if (!localizingThreads.contains(shell.getWaitingThread())) {
      continue;
    }
    // Read the process once and guard against null, mirroring the check
    // done in Shell.destroyAllShellProcesses().
    final Process process = shell.getProcess();
    if (process != null) {
      process.destroy();
    }
  }
}
} }

View File

@ -17,6 +17,8 @@
*/ */
package org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer; package org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer;
import static junit.framework.TestCase.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyBoolean;
@ -25,6 +27,7 @@ import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
@ -46,6 +49,7 @@ import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import com.google.common.base.Supplier;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
@ -60,6 +64,9 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.security.Credentials; import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.util.Shell.ShellCommandExecutor;
import org.apache.hadoop.yarn.api.records.LocalResource; import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.LocalResourceType; import org.apache.hadoop.yarn.api.records.LocalResourceType;
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
@ -76,6 +83,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatcher;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -92,18 +100,18 @@ public class TestContainerLocalizer {
static final InetSocketAddress nmAddr = static final InetSocketAddress nmAddr =
new InetSocketAddress("foobar", 8040); new InetSocketAddress("foobar", 8040);
private AbstractFileSystem spylfs;
private Random random;
private List<Path> localDirs;
private Path tokenPath;
private LocalizationProtocol nmProxy;
@Test @Test
public void testMain() throws Exception { public void testMain() throws Exception {
FileContext fs = FileContext.getLocalFSFileContext(); ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper();
spylfs = spy(fs.getDefaultFileSystem());
ContainerLocalizer localizer = ContainerLocalizer localizer =
setupContainerLocalizerForTest(); wrapper.setupContainerLocalizerForTest();
Random random = wrapper.random;
List<Path> localDirs = wrapper.localDirs;
Path tokenPath = wrapper.tokenPath;
LocalizationProtocol nmProxy = wrapper.nmProxy;
AbstractFileSystem spylfs = wrapper.spylfs;
mockOutDownloads(localizer);
// verify created cache // verify created cache
List<Path> privCacheList = new ArrayList<Path>(); List<Path> privCacheList = new ArrayList<Path>();
@ -131,7 +139,7 @@ public class TestContainerLocalizer {
ResourceLocalizationSpec rsrcD = ResourceLocalizationSpec rsrcD =
getMockRsrc(random, LocalResourceVisibility.PRIVATE, getMockRsrc(random, LocalResourceVisibility.PRIVATE,
privCacheList.get(0)); privCacheList.get(0));
when(nmProxy.heartbeat(isA(LocalizerStatus.class))) when(nmProxy.heartbeat(isA(LocalizerStatus.class)))
.thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE, .thenReturn(new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
Collections.singletonList(rsrcA))) Collections.singletonList(rsrcA)))
@ -202,10 +210,10 @@ public class TestContainerLocalizer {
@Test(timeout = 15000) @Test(timeout = 15000)
public void testMainFailure() throws Exception { public void testMainFailure() throws Exception {
ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper();
FileContext fs = FileContext.getLocalFSFileContext(); ContainerLocalizer localizer = wrapper.setupContainerLocalizerForTest();
spylfs = spy(fs.getDefaultFileSystem()); LocalizationProtocol nmProxy = wrapper.nmProxy;
ContainerLocalizer localizer = setupContainerLocalizerForTest(); mockOutDownloads(localizer);
// Assume the NM heartbeat fails say because of absent tokens. // Assume the NM heartbeat fails say because of absent tokens.
when(nmProxy.heartbeat(isA(LocalizerStatus.class))).thenThrow( when(nmProxy.heartbeat(isA(LocalizerStatus.class))).thenThrow(
@ -223,9 +231,11 @@ public class TestContainerLocalizer {
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testLocalizerTokenIsGettingRemoved() throws Exception { public void testLocalizerTokenIsGettingRemoved() throws Exception {
FileContext fs = FileContext.getLocalFSFileContext(); ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper();
spylfs = spy(fs.getDefaultFileSystem()); ContainerLocalizer localizer = wrapper.setupContainerLocalizerForTest();
ContainerLocalizer localizer = setupContainerLocalizerForTest(); Path tokenPath = wrapper.tokenPath;
AbstractFileSystem spylfs = wrapper.spylfs;
mockOutDownloads(localizer);
doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class),
any(CompletionService.class), any(UserGroupInformation.class)); any(CompletionService.class), any(UserGroupInformation.class));
localizer.runLocalization(nmAddr); localizer.runLocalization(nmAddr);
@ -237,10 +247,10 @@ public class TestContainerLocalizer {
public void testContainerLocalizerClosesFilesystems() throws Exception { public void testContainerLocalizerClosesFilesystems() throws Exception {
// verify filesystems are closed when localizer doesn't fail // verify filesystems are closed when localizer doesn't fail
FileContext fs = FileContext.getLocalFSFileContext(); ContainerLocalizerWrapper wrapper = new ContainerLocalizerWrapper();
spylfs = spy(fs.getDefaultFileSystem());
ContainerLocalizer localizer = setupContainerLocalizerForTest(); ContainerLocalizer localizer = wrapper.setupContainerLocalizerForTest();
mockOutDownloads(localizer);
doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class), doNothing().when(localizer).localizeFiles(any(LocalizationProtocol.class),
any(CompletionService.class), any(UserGroupInformation.class)); any(CompletionService.class), any(UserGroupInformation.class));
verify(localizer, never()).closeFileSystems( verify(localizer, never()).closeFileSystems(
@ -249,10 +259,8 @@ public class TestContainerLocalizer {
localizer.runLocalization(nmAddr); localizer.runLocalization(nmAddr);
verify(localizer).closeFileSystems(any(UserGroupInformation.class)); verify(localizer).closeFileSystems(any(UserGroupInformation.class));
spylfs = spy(fs.getDefaultFileSystem());
// verify filesystems are closed when localizer fails // verify filesystems are closed when localizer fails
localizer = setupContainerLocalizerForTest(); localizer = wrapper.setupContainerLocalizerForTest();
doThrow(new YarnRuntimeException("Forced Failure")).when(localizer).localizeFiles( doThrow(new YarnRuntimeException("Forced Failure")).when(localizer).localizeFiles(
any(LocalizationProtocol.class), any(CompletionService.class), any(LocalizationProtocol.class), any(CompletionService.class),
any(UserGroupInformation.class)); any(UserGroupInformation.class));
@ -266,41 +274,109 @@ public class TestContainerLocalizer {
} }
} }
@SuppressWarnings("unchecked") // mocked generics @Test(timeout = 30000)
private ContainerLocalizer setupContainerLocalizerForTest() public void testMultipleLocalizers() throws Exception {
throws Exception { FakeContainerLocalizerWrapper testA = new FakeContainerLocalizerWrapper();
// don't actually create dirs FakeContainerLocalizerWrapper testB = new FakeContainerLocalizerWrapper();
doNothing().when(spylfs).mkdir(
isA(Path.class), isA(FsPermission.class), anyBoolean());
Configuration conf = new Configuration(); final FakeContainerLocalizer localizerA = testA.init();
FileContext lfs = FileContext.getFileContext(spylfs, conf); final FakeContainerLocalizer localizerB = testB.init();
localDirs = new ArrayList<Path>();
for (int i = 0; i < 4; ++i) { // run localization
localDirs.add(lfs.makeQualified(new Path(basedir, i + ""))); Thread threadA = new Thread() {
@Override
public void run() {
try {
localizerA.runLocalization(nmAddr);
} catch (Exception e) {
LOG.warn(e);
}
}
};
Thread threadB = new Thread() {
@Override
public void run() {
try {
localizerB.runLocalization(nmAddr);
} catch (Exception e) {
LOG.warn(e);
}
}
};
ShellCommandExecutor shexcA = null;
ShellCommandExecutor shexcB = null;
try {
threadA.start();
threadB.start();
GenericTestUtils.waitFor(new Supplier<Boolean>() {
@Override
public Boolean get() {
FakeContainerLocalizer.FakeLongDownload downloader =
localizerA.getDownloader();
return downloader != null && downloader.getShexc() != null &&
downloader.getShexc().getProcess() != null;
}
}, 10, 30000);
GenericTestUtils.waitFor(new Supplier<Boolean>() {
@Override
public Boolean get() {
FakeContainerLocalizer.FakeLongDownload downloader =
localizerB.getDownloader();
return downloader != null && downloader.getShexc() != null &&
downloader.getShexc().getProcess() != null;
}
}, 10, 30000);
shexcA = localizerA.getDownloader().getShexc();
shexcB = localizerB.getDownloader().getShexc();
assertTrue("Localizer A process not running, but should be",
isAlive(shexcA.getProcess()));
assertTrue("Localizer B process not running, but should be",
isAlive(shexcB.getProcess()));
// Stop heartbeat from giving anymore resources to download
testA.heartbeatResponse++;
testB.heartbeatResponse++;
// Send DIE to localizerA. This should kill its subprocesses
testA.heartbeatResponse++;
threadA.join();
shexcA.getProcess().waitFor();
assertFalse("Localizer A process is still running, but shouldn't be",
isAlive(shexcA.getProcess()));
assertTrue("Localizer B process not running, but should be",
isAlive(shexcB.getProcess()));
} finally {
// Make sure everything gets cleaned up
// Process A should already be dead
shexcA.getProcess().destroy();
shexcB.getProcess().destroy();
shexcA.getProcess().waitFor();
shexcB.getProcess().waitFor();
threadA.join();
// Send DIE to localizer B
testB.heartbeatResponse++;
threadB.join();
} }
RecordFactory mockRF = getMockLocalizerRecordFactory(); }
ContainerLocalizer concreteLoc = new ContainerLocalizer(lfs, appUser,
appId, containerId, localDirs, mockRF);
ContainerLocalizer localizer = spy(concreteLoc);
// return credential stream instead of opening local file private boolean isAlive(Process p) {
random = new Random(); try {
long seed = random.nextLong(); p.exitValue();
System.out.println("SEED: " + seed); return false;
random.setSeed(seed); } catch(IllegalThreadStateException e) {
DataInputBuffer appTokens = createFakeCredentials(random, 10); return true;
tokenPath = }
lfs.makeQualified(new Path( }
String.format(ContainerLocalizer.TOKEN_FILE_NAME_FMT,
containerId)));
doReturn(new FSDataInputStream(new FakeFSDataInputStream(appTokens))
).when(spylfs).open(tokenPath);
nmProxy = mock(LocalizationProtocol.class);
doReturn(nmProxy).when(localizer).getProxy(nmAddr);
doNothing().when(localizer).sleep(anyInt());
private void mockOutDownloads(ContainerLocalizer localizer) {
// return result instantly for deterministic test // return result instantly for deterministic test
ExecutorService syncExec = mock(ExecutorService.class); ExecutorService syncExec = mock(ExecutorService.class);
CompletionService<Path> cs = mock(CompletionService.class); CompletionService<Path> cs = mock(CompletionService.class);
@ -318,8 +394,6 @@ public class TestContainerLocalizer {
}); });
doReturn(syncExec).when(localizer).createDownloadThreadPool(); doReturn(syncExec).when(localizer).createDownloadThreadPool();
doReturn(cs).when(localizer).createCompletionService(syncExec); doReturn(cs).when(localizer).createCompletionService(syncExec);
return localizer;
} }
static class HBMatches extends ArgumentMatcher<LocalizerStatus> { static class HBMatches extends ArgumentMatcher<LocalizerStatus> {
@ -363,6 +437,141 @@ public class TestContainerLocalizer {
} }
} }
/**
 * A {@link ContainerLocalizer} for tests whose downloads are replaced by a
 * long-running shell command, so a test can observe and kill the spawned
 * sub-process.
 */
class FakeContainerLocalizer extends ContainerLocalizer {
  // Written by the thread that calls download() and polled from the test
  // thread, hence volatile so the poller is guaranteed to see the update.
  private volatile FakeLongDownload downloader;

  FakeContainerLocalizer(FileContext lfs, String user, String appId,
      String localizerId, List<Path> localDirs,
      RecordFactory recordFactory) throws IOException {
    super(lfs, user, appId, localizerId, localDirs, recordFactory);
  }

  /**
   * @return the most recently created fake download, or null if
   *         {@link #download} has not been called yet
   */
  FakeLongDownload getDownloader() {
    return downloader;
  }

  /**
   * Creates a {@link FakeLongDownload} instead of a real download and
   * remembers it so the test can reach its shell executor.
   */
  @Override
  Callable<Path> download(Path path, LocalResource rsrc,
      UserGroupInformation ugi) throws IOException {
    downloader = new FakeLongDownload(Mockito.mock(FileContext.class), ugi,
        new Configuration(), path, rsrc);
    return downloader;
  }

  /**
   * A download whose work consists of running <code>sleep 30</code> in a
   * bash sub-process and then returning a fixed local path.
   */
  class FakeLongDownload extends ContainerLocalizer.FSDownloadWrapper {
    private final Path localPath;
    // Assigned on the download thread, read from the test thread.
    private volatile Shell.ShellCommandExecutor shexc;

    FakeLongDownload(FileContext files, UserGroupInformation ugi,
        Configuration conf, Path destDirPath, LocalResource resource) {
      super(files, ugi, conf, destDirPath, resource);
      this.localPath = new Path("file:///localcache");
    }

    /**
     * @return the executor running the sleep command, or null if the
     *         download has not started yet
     */
    Shell.ShellCommandExecutor getShexc() {
      return shexc;
    }

    /**
     * Runs the sleep command synchronously and returns the fixed path.
     *
     * @throws IOException if the shell command fails
     */
    @Override
    public Path doDownloadCall() throws IOException {
      String sleepCommand = "sleep 30";
      String[] shellCmd = {"bash", "-c", sleepCommand};
      // Publish the executor before blocking in execute() so that a
      // polling thread can find it while the command is still running.
      shexc = new Shell.ShellCommandExecutor(shellCmd);
      shexc.execute();
      return localPath;
    }
  }
}
/**
 * Test fixture that builds a spied {@link FakeContainerLocalizer} on top
 * of a spied local file system and keeps the collaborators it created in
 * fields so individual tests can stub or verify them.
 */
class ContainerLocalizerWrapper {
  AbstractFileSystem spylfs;
  Random random;
  List<Path> localDirs;
  Path tokenPath;
  LocalizationProtocol nmProxy;

  /**
   * Creates a fresh spied localizer and (re)initialises every field of
   * this wrapper.
   *
   * @return the spied localizer, with getProxy() and sleep() stubbed
   * @throws Exception if any of the set-up steps fails
   */
  @SuppressWarnings("unchecked") // mocked generics
  FakeContainerLocalizer setupContainerLocalizerForTest()
      throws Exception {
    FileContext localFc = FileContext.getLocalFSFileContext();
    spylfs = spy(localFc.getDefaultFileSystem());

    // directory creation is stubbed out, nothing is written to disk
    doNothing().when(spylfs).mkdir(
        isA(Path.class), isA(FsPermission.class), anyBoolean());

    Configuration configuration = new Configuration();
    FileContext lfs = FileContext.getFileContext(spylfs, configuration);

    // four qualified local directories named "0" .. "3" under basedir
    localDirs = new ArrayList<Path>();
    for (int dirIndex = 0; dirIndex < 4; dirIndex++) {
      Path unqualified = new Path(basedir, Integer.toString(dirIndex));
      localDirs.add(lfs.makeQualified(unqualified));
    }

    RecordFactory recordFactory = getMockLocalizerRecordFactory();
    FakeContainerLocalizer realLocalizer = new FakeContainerLocalizer(lfs,
        appUser, appId, containerId, localDirs, recordFactory);
    FakeContainerLocalizer spiedLocalizer = spy(realLocalizer);

    // serve a fake credential stream rather than opening a local file;
    // the seed is printed so a run can be reproduced
    random = new Random();
    long seedValue = random.nextLong();
    System.out.println("SEED: " + seedValue);
    random.setSeed(seedValue);
    DataInputBuffer fakeTokens = createFakeCredentials(random, 10);

    String tokenFileName = String.format(
        ContainerLocalizer.TOKEN_FILE_NAME_FMT, containerId);
    tokenPath = lfs.makeQualified(new Path(tokenFileName));
    FSDataInputStream tokenStream =
        new FSDataInputStream(new FakeFSDataInputStream(fakeTokens));
    doReturn(tokenStream).when(spylfs).open(tokenPath);

    nmProxy = mock(LocalizationProtocol.class);
    doReturn(nmProxy).when(spiedLocalizer).getProxy(nmAddr);
    doNothing().when(spiedLocalizer).sleep(anyInt());

    return spiedLocalizer;
  }
}
/**
 * Fixture whose mocked NM heartbeat is driven by a counter that the test
 * advances:
 * <ul>
 *   <li>0 - LIVE with a single resource to download</li>
 *   <li>1 - LIVE with no resources</li>
 *   <li>2 or more - DIE</li>
 * </ul>
 */
class FakeContainerLocalizerWrapper extends ContainerLocalizerWrapper{
  // Incremented by the test thread and read by the heartbeat answer on the
  // localizer thread, hence volatile.
  private volatile int heartbeatResponse = 0;

  /**
   * Builds the spied localizer and installs the counter-driven heartbeat
   * answer on its NM proxy.
   *
   * @return the spied localizer, ready for runLocalization()
   * @throws Exception if the underlying set-up fails
   */
  public FakeContainerLocalizer init() throws Exception {
    FakeContainerLocalizer localizer = setupContainerLocalizerForTest();

    // compute the per-user private file cache path under each local dir
    List<Path> privCacheList = new ArrayList<Path>();
    for (Path p : localDirs) {
      Path base = new Path(new Path(p, ContainerLocalizer.USERCACHE),
          appUser);
      Path privcache = new Path(base, ContainerLocalizer.FILECACHE);
      privCacheList.add(privcache);
    }
    final ResourceLocalizationSpec rsrc = getMockRsrc(random,
        LocalResourceVisibility.PRIVATE, privCacheList.get(0));

    // mock heartbeat responses from NM
    doAnswer(new Answer<MockLocalizerHeartbeatResponse>() {
      @Override
      public MockLocalizerHeartbeatResponse answer(
          InvocationOnMock invocationOnMock) throws Throwable {
        // read the counter once so both comparisons see the same value
        final int state = heartbeatResponse;
        if (state == 0) {
          return new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
              Collections.singletonList(rsrc));
        } else if (state < 2) {
          return new MockLocalizerHeartbeatResponse(LocalizerAction.LIVE,
              Collections.<ResourceLocalizationSpec>emptyList());
        } else {
          return new MockLocalizerHeartbeatResponse(LocalizerAction.DIE,
              null);
        }
      }
    }).when(nmProxy).heartbeat(isA(LocalizerStatus.class));
    return localizer;
  }
}
static RecordFactory getMockLocalizerRecordFactory() { static RecordFactory getMockLocalizerRecordFactory() {
RecordFactory mockRF = mock(RecordFactory.class); RecordFactory mockRF = mock(RecordFactory.class);
when(mockRF.newRecordInstance(same(LocalResourceStatus.class))) when(mockRF.newRecordInstance(same(LocalResourceStatus.class)))