YARN-8769. [Submarine] Allow user to specify customized quicklink(s) when submit Submarine job. Contributed by Wangda Tan.

This commit is contained in:
Sunil G 2018-09-21 23:39:22 +05:30
parent a2752779ac
commit 0cd6346102
7 changed files with 303 additions and 35 deletions

View File

@ -49,6 +49,7 @@ public class CliConstants {
public static final String WAIT_JOB_FINISH = "wait_job_finish"; public static final String WAIT_JOB_FINISH = "wait_job_finish";
public static final String PS_DOCKER_IMAGE = "ps_docker_image"; public static final String PS_DOCKER_IMAGE = "ps_docker_image";
public static final String WORKER_DOCKER_IMAGE = "worker_docker_image"; public static final String WORKER_DOCKER_IMAGE = "worker_docker_image";
public static final String QUICKLINK = "quicklink";
public static final String TENSORBOARD_DOCKER_IMAGE = public static final String TENSORBOARD_DOCKER_IMAGE =
"tensorboard_docker_image"; "tensorboard_docker_image";
} }

View File

@ -117,6 +117,14 @@ private Options generateOptions() {
options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true, options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
"Specify docker image for WORKER, when this is not specified, WORKER " "Specify docker image for WORKER, when this is not specified, WORKER "
+ "uses --" + CliConstants.DOCKER_IMAGE + " as default."); + "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
+ "web UI shows link to given role instance and port. When "
+ "--tensorboard is speciied, quicklink to tensorboard instance will "
+ "be added automatically. The format of quick link is: "
+ "Quick_link_label=http(or https)://role-name:port. For example, "
+ "if want to link to first worker's 7070 port, and text of quicklink "
+ "is Notebook_UI, user need to specify --quicklink "
+ "Notebook_UI=https://master-0:7070");
options.addOption("h", "help", false, "Print help"); options.addOption("h", "help", false, "Print help");
return options; return options;
} }

View File

@ -0,0 +1,71 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. See accompanying LICENSE file.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param;
import org.apache.commons.cli.ParseException;
/**
* A class represents quick links to a web page.
*/
public class Quicklink {
private String label;
private String componentInstanceName;
private String protocol;
private int port;
public void parse(String quicklinkStr) throws ParseException {
if (!quicklinkStr.contains("=")) {
throw new ParseException("Should be <label>=<link> format for quicklink");
}
int index = quicklinkStr.indexOf("=");
label = quicklinkStr.substring(0, index);
quicklinkStr = quicklinkStr.substring(index + 1);
if (quicklinkStr.startsWith("http://")) {
protocol = "http://";
} else if (quicklinkStr.startsWith("https://")) {
protocol = "https://";
} else {
throw new ParseException("Quicklink should start with http or https");
}
quicklinkStr = quicklinkStr.substring(protocol.length());
index = quicklinkStr.indexOf(":");
if (index == -1) {
throw new ParseException("Quicklink should be componet-id:port form");
}
componentInstanceName = quicklinkStr.substring(0, index);
port = Integer.parseInt(quicklinkStr.substring(index + 1));
}
public String getLabel() {
return label;
}
public String getComponentInstanceName() {
return componentInstanceName;
}
public String getProtocol() {
return protocol;
}
public int getPort() {
return port;
}
}

View File

@ -24,6 +24,8 @@
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/** /**
* Parameters used to run a job * Parameters used to run a job
@ -41,6 +43,7 @@ public class RunJobParameters extends RunParameters {
private String tensorboardDockerImage; private String tensorboardDockerImage;
private String workerLaunchCmd; private String workerLaunchCmd;
private String psLaunchCmd; private String psLaunchCmd;
private List<Quicklink> quicklinks = new ArrayList<>();
private String psDockerImage = null; private String psDockerImage = null;
private String workerDockerImage = null; private String workerDockerImage = null;
@ -119,6 +122,17 @@ public void updateParametersByParsedCommandline(CommandLine parsedCommandLine,
this.waitJobFinish = true; this.waitJobFinish = true;
} }
// Quicklinks
String[] quicklinkStrs = parsedCommandLine.getOptionValues(
CliConstants.QUICKLINK);
if (quicklinkStrs != null) {
for (String ql : quicklinkStrs) {
Quicklink quicklink = new Quicklink();
quicklink.parse(ql);
quicklinks.add(quicklink);
}
}
psDockerImage = parsedCommandLine.getOptionValue( psDockerImage = parsedCommandLine.getOptionValue(
CliConstants.PS_DOCKER_IMAGE); CliConstants.PS_DOCKER_IMAGE);
workerDockerImage = parsedCommandLine.getOptionValue( workerDockerImage = parsedCommandLine.getOptionValue(
@ -247,4 +261,8 @@ public void setTensorboardResource(Resource tensorboardResource) {
public String getTensorboardDockerImage() { public String getTensorboardDockerImage() {
return tensorboardDockerImage; return tensorboardDockerImage;
} }
public List<Quicklink> getQuicklinks() {
return quicklinks;
}
} }

View File

@ -15,7 +15,6 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
@ -29,6 +28,7 @@
import org.apache.hadoop.yarn.service.api.records.ResourceInformation; import org.apache.hadoop.yarn.service.api.records.ResourceInformation;
import org.apache.hadoop.yarn.service.api.records.Service; import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.service.client.ServiceClient; import org.apache.hadoop.yarn.service.client.ServiceClient;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.Envs; import org.apache.hadoop.yarn.submarine.common.Envs;
@ -40,10 +40,14 @@
import java.io.File; import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.io.FileWriter; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.StringTokenizer; import java.util.StringTokenizer;
@ -54,6 +58,7 @@
* Submit a job to cluster * Submit a job to cluster
*/ */
public class YarnServiceJobSubmitter implements JobSubmitter { public class YarnServiceJobSubmitter implements JobSubmitter {
public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard";
private static final Logger LOG = private static final Logger LOG =
LoggerFactory.getLogger(YarnServiceJobSubmitter.class); LoggerFactory.getLogger(YarnServiceJobSubmitter.class);
ClientContext clientContext; ClientContext clientContext;
@ -98,7 +103,7 @@ private String getValueOfEnvionment(String envar) {
} }
private void addHdfsClassPathIfNeeded(RunJobParameters parameters, private void addHdfsClassPathIfNeeded(RunJobParameters parameters,
FileWriter fw, Component comp) throws IOException { PrintWriter fw, Component comp) throws IOException {
// Find envs to use HDFS // Find envs to use HDFS
String hdfsHome = null; String hdfsHome = null;
String javaHome = null; String javaHome = null;
@ -191,7 +196,8 @@ private void addCommonEnvironments(Component component, TaskType taskType) {
envs.put(Envs.TASK_TYPE_ENV, taskType.name()); envs.put(Envs.TASK_TYPE_ENV, taskType.name());
} }
private String getUserName() { @VisibleForTesting
protected String getUserName() {
return System.getProperty("user.name"); return System.getProperty("user.name");
} }
@ -205,18 +211,19 @@ private String getDNSDomain() {
private String generateCommandLaunchScript(RunJobParameters parameters, private String generateCommandLaunchScript(RunJobParameters parameters,
TaskType taskType, Component comp) throws IOException { TaskType taskType, Component comp) throws IOException {
File file = File.createTempFile(taskType.name() + "-launch-script", ".sh"); File file = File.createTempFile(taskType.name() + "-launch-script", ".sh");
FileWriter fw = new FileWriter(file); Writer w = new OutputStreamWriter(new FileOutputStream(file), "UTF-8");
PrintWriter pw = new PrintWriter(w);
try { try {
fw.append("#!/bin/bash\n"); pw.append("#!/bin/bash\n");
addHdfsClassPathIfNeeded(parameters, fw, comp); addHdfsClassPathIfNeeded(parameters, pw, comp);
if (taskType.equals(TaskType.TENSORBOARD)) { if (taskType.equals(TaskType.TENSORBOARD)) {
String tbCommand = String tbCommand =
"export LC_ALL=C && tensorboard --logdir=" + parameters "export LC_ALL=C && tensorboard --logdir=" + parameters
.getCheckpointPath(); .getCheckpointPath();
fw.append(tbCommand + "\n"); pw.append(tbCommand + "\n");
LOG.info("Tensorboard command=" + tbCommand); LOG.info("Tensorboard command=" + tbCommand);
} else{ } else{
// When distributed training is required // When distributed training is required
@ -226,20 +233,20 @@ private String generateCommandLaunchScript(RunJobParameters parameters,
taskType.getComponentName(), parameters.getNumWorkers(), taskType.getComponentName(), parameters.getNumWorkers(),
parameters.getNumPS(), parameters.getName(), getUserName(), parameters.getNumPS(), parameters.getName(), getUserName(),
getDNSDomain()); getDNSDomain());
fw.append("export TF_CONFIG=\"" + tfConfigEnv + "\"\n"); pw.append("export TF_CONFIG=\"" + tfConfigEnv + "\"\n");
} }
// Print launch command // Print launch command
if (taskType.equals(TaskType.WORKER) || taskType.equals( if (taskType.equals(TaskType.WORKER) || taskType.equals(
TaskType.PRIMARY_WORKER)) { TaskType.PRIMARY_WORKER)) {
fw.append(parameters.getWorkerLaunchCmd() + '\n'); pw.append(parameters.getWorkerLaunchCmd() + '\n');
if (SubmarineLogs.isVerbose()) { if (SubmarineLogs.isVerbose()) {
LOG.info( LOG.info(
"Worker command =[" + parameters.getWorkerLaunchCmd() + "]"); "Worker command =[" + parameters.getWorkerLaunchCmd() + "]");
} }
} else if (taskType.equals(TaskType.PS)) { } else if (taskType.equals(TaskType.PS)) {
fw.append(parameters.getPSLaunchCmd() + '\n'); pw.append(parameters.getPSLaunchCmd() + '\n');
if (SubmarineLogs.isVerbose()) { if (SubmarineLogs.isVerbose()) {
LOG.info("PS command =[" + parameters.getPSLaunchCmd() + "]"); LOG.info("PS command =[" + parameters.getPSLaunchCmd() + "]");
@ -247,7 +254,7 @@ private String generateCommandLaunchScript(RunJobParameters parameters,
} }
} }
} finally { } finally {
fw.close(); pw.close();
} }
return file.getAbsolutePath(); return file.getAbsolutePath();
} }
@ -421,18 +428,51 @@ private Artifact getDockerArtifact(String dockerImageName) {
return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName); return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName);
} }
private void handleQuicklinks(RunJobParameters runJobParameters)
throws IOException {
List<Quicklink> quicklinks = runJobParameters.getQuicklinks();
if (null != quicklinks && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol() + YarnServiceUtils.getDNSName(
serviceSpec.getName(), instanceName, getUserName(), getDNSDomain(),
ql.getPort());
YarnServiceUtils.addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
private Service createServiceByParameters(RunJobParameters parameters) private Service createServiceByParameters(RunJobParameters parameters)
throws IOException { throws IOException {
componentToLocalLaunchScriptPath.clear(); componentToLocalLaunchScriptPath.clear();
Service service = new Service(); serviceSpec = new Service();
service.setName(parameters.getName()); serviceSpec.setName(parameters.getName());
service.setVersion(String.valueOf(System.currentTimeMillis())); serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
service.setArtifact(getDockerArtifact(parameters.getDockerImageName())); serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
handleServiceEnvs(service, parameters); handleServiceEnvs(serviceSpec, parameters);
if (parameters.getNumWorkers() > 0) { if (parameters.getNumWorkers() > 0) {
addWorkerComponents(service, parameters); addWorkerComponents(serviceSpec, parameters);
} }
if (parameters.getNumPS() > 0) { if (parameters.getNumPS() > 0) {
@ -450,7 +490,7 @@ private Service createServiceByParameters(RunJobParameters parameters)
getDockerArtifact(parameters.getPsDockerImage())); getDockerArtifact(parameters.getPsDockerImage()));
} }
handleLaunchCommand(parameters, TaskType.PS, psComponent); handleLaunchCommand(parameters, TaskType.PS, psComponent);
service.addComponent(psComponent); serviceSpec.addComponent(psComponent);
} }
if (parameters.isTensorboardEnabled()) { if (parameters.isTensorboardEnabled()) {
@ -470,14 +510,20 @@ private Service createServiceByParameters(RunJobParameters parameters)
// Add tensorboard to quicklink // Add tensorboard to quicklink
String tensorboardLink = "http://" + YarnServiceUtils.getDNSName( String tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
parameters.getName(), TaskType.TENSORBOARD.getComponentName(), 0, parameters.getName(),
getUserName(), getDNSDomain(), 6006); TaskType.TENSORBOARD.getComponentName() + "-" + 0, getUserName(),
getDNSDomain(), 6006);
LOG.info("Link to tensorboard:" + tensorboardLink); LOG.info("Link to tensorboard:" + tensorboardLink);
service.addComponent(tbComponent); serviceSpec.addComponent(tbComponent);
service.setQuicklinks(ImmutableMap.of("Tensorboard", tensorboardLink));
YarnServiceUtils.addQuicklink(serviceSpec, TENSORBOARD_QUICKLINK_LABEL,
tensorboardLink);
} }
return service; // After all components added, handle quicklinks
handleQuicklinks(parameters);
return serviceSpec;
} }
/** /**
@ -486,12 +532,11 @@ private Service createServiceByParameters(RunJobParameters parameters)
@Override @Override
public ApplicationId submitJob(RunJobParameters parameters) public ApplicationId submitJob(RunJobParameters parameters)
throws IOException, YarnException { throws IOException, YarnException {
Service service = createServiceByParameters(parameters); createServiceByParameters(parameters);
ServiceClient serviceClient = YarnServiceUtils.createServiceClient( ServiceClient serviceClient = YarnServiceUtils.createServiceClient(
clientContext.getYarnConfig()); clientContext.getYarnConfig());
ApplicationId appid = serviceClient.actionCreate(service); ApplicationId appid = serviceClient.actionCreate(serviceSpec);
serviceClient.stop(); serviceClient.stop();
this.serviceSpec = service;
return appid; return appid;
} }

View File

@ -16,10 +16,20 @@
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.service.client.ServiceClient; import org.apache.hadoop.yarn.service.client.ServiceClient;
import org.apache.hadoop.yarn.submarine.common.Envs; import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
public class YarnServiceUtils { public class YarnServiceUtils {
private static final Logger LOG =
LoggerFactory.getLogger(YarnServiceUtils.class);
// This will be true only in UT. // This will be true only in UT.
private static ServiceClient stubServiceClient = null; private static ServiceClient stubServiceClient = null;
@ -40,10 +50,10 @@ public static void setStubServiceClient(ServiceClient stubServiceClient) {
YarnServiceUtils.stubServiceClient = stubServiceClient; YarnServiceUtils.stubServiceClient = stubServiceClient;
} }
public static String getDNSName(String serviceName, String componentName, public static String getDNSName(String serviceName,
int index, String userName, String domain, int port) { String componentInstanceName, String userName, String domain, int port) {
return componentName + "-" + index + getDNSNameCommonSuffix(serviceName, return componentInstanceName + getDNSNameCommonSuffix(serviceName, userName,
userName, domain, port); domain, port);
} }
private static String getDNSNameCommonSuffix(String serviceName, private static String getDNSNameCommonSuffix(String serviceName,
@ -66,12 +76,18 @@ public static String getTFConfigEnv(String curCommponentName, int nWorkers,
commonEndpointSuffix) + ","; commonEndpointSuffix) + ",";
String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},"; String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";
String task = StringBuilder sb = new StringBuilder();
"\\\"task\\\":{" + " \\\"type\\\":\\\"" + curCommponentName + "\\\"," sb.append("\\\"task\\\":{");
+ " \\\"index\\\":" + '$' + Envs.TASK_INDEX_ENV + "},"; sb.append(" \\\"type\\\":\\\"");
sb.append(curCommponentName);
sb.append("\\\",");
sb.append(" \\\"index\\\":");
sb.append('$');
sb.append(Envs.TASK_INDEX_ENV + "},");
String task = sb.toString();
String environment = "\\\"environment\\\":\\\"cloud\\\"}"; String environment = "\\\"environment\\\":\\\"cloud\\\"}";
StringBuilder sb = new StringBuilder(); sb = new StringBuilder();
sb.append(json); sb.append(json);
sb.append(master); sb.append(master);
sb.append(worker); sb.append(worker);
@ -81,6 +97,21 @@ public static String getTFConfigEnv(String curCommponentName, int nWorkers,
return sb.toString(); return sb.toString();
} }
public static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (null == quicklinks) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
private static String getComponentArrayJson(String componentName, int count, private static String getComponentArrayJson(String componentName, int count,
String endpointSuffix) { String endpointSuffix) {
String component = "\\\"" + componentName + "\\\":"; String component = "\\\"" + componentName + "\\\":";

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.submarine.client.cli.yarnservice; package org.apache.hadoop.yarn.submarine.client.cli.yarnservice;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.Resource; import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
@ -100,6 +101,32 @@ private void commonVerifyDistributedTrainingSpec(Service serviceSpec)
Assert.assertTrue(SubmarineLogs.isVerbose()); Assert.assertTrue(SubmarineLogs.isVerbose());
} }
private void verifyQuicklink(Service serviceSpec,
Map<String, String> expectedQuicklinks) {
Map<String, String> actualQuicklinks = serviceSpec.getQuicklinks();
if (actualQuicklinks == null || actualQuicklinks.isEmpty()) {
Assert.assertTrue(
expectedQuicklinks == null || expectedQuicklinks.isEmpty());
return;
}
Assert.assertEquals(expectedQuicklinks.size(), actualQuicklinks.size());
for (Map.Entry<String, String> expectedEntry : expectedQuicklinks
.entrySet()) {
Assert.assertTrue(actualQuicklinks.containsKey(expectedEntry.getKey()));
// $USER could be changed in different environment. so replace $USER by
// "user"
String expectedValue = expectedEntry.getValue();
String actualValue = actualQuicklinks.get(expectedEntry.getKey());
String userName = System.getProperty("user.name");
actualValue = actualValue.replaceAll(userName, "username");
Assert.assertEquals(expectedValue, actualValue);
}
}
@Test @Test
public void testBasicRunJobForDistributedTraining() throws Exception { public void testBasicRunJobForDistributedTraining() throws Exception {
MockClientContext mockClientContext = MockClientContext mockClientContext =
@ -120,6 +147,8 @@ public void testBasicRunJobForDistributedTraining() throws Exception {
Assert.assertEquals(3, serviceSpec.getComponents().size()); Assert.assertEquals(3, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec); commonVerifyDistributedTrainingSpec(serviceSpec);
verifyQuicklink(serviceSpec, null);
} }
@Test @Test
@ -147,6 +176,10 @@ public void testBasicRunJobForDistributedTrainingWithTensorboard()
verifyTensorboardComponent(runJobCli, serviceSpec, verifyTensorboardComponent(runJobCli, serviceSpec,
Resources.createResource(4096, 1)); Resources.createResource(4096, 1));
verifyQuicklink(serviceSpec, ImmutableMap
.of(YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL,
"http://tensorboard-0.my-job.username.null:6006"));
} }
@Test @Test
@ -232,6 +265,9 @@ public void testTensorboardOnlyServiceWithCustomizedDockerImageAndResource()
verifyTensorboardComponent(runJobCli, serviceSpec, verifyTensorboardComponent(runJobCli, serviceSpec,
Resources.createResource(2048, 2)); Resources.createResource(2048, 2));
verifyQuicklink(serviceSpec, ImmutableMap
.of(YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL,
"http://tensorboard-0.my-job.username.null:6006"));
} }
private void commonTestSingleNodeTraining(Service serviceSpec) private void commonTestSingleNodeTraining(Service serviceSpec)
@ -372,4 +408,62 @@ public void testParameterStorageForTrainingJob() throws Exception {
Assert.assertEquals(jobInfo.get(StorageKeyConstants.INPUT_PATH), Assert.assertEquals(jobInfo.get(StorageKeyConstants.INPUT_PATH),
"s3://input"); "s3://input");
} }
@Test
public void testAddQuicklinksWithoutTensorboard() throws Exception {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "s3://input", "--checkpoint_path", "s3://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
"ps.image", "--worker_docker_image", "worker.image",
"--ps_launch_cmd", "python run-ps.py", "--verbose", "--quicklink",
"AAA=http://master-0:8321", "--quicklink",
"BBB=http://worker-0:1234" });
Service serviceSpec = getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
Assert.assertEquals(3, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec);
verifyQuicklink(serviceSpec, ImmutableMap
.of("AAA", "http://master-0.my-job.username.null:8321", "BBB",
"http://worker-0.my-job.username.null:1234"));
}
@Test
public void testAddQuicklinksWithTensorboard() throws Exception {
MockClientContext mockClientContext =
YarnServiceCliTestUtils.getMockClientContext();
RunJobCli runJobCli = new RunJobCli(mockClientContext);
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "s3://input", "--checkpoint_path", "s3://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image",
"ps.image", "--worker_docker_image", "worker.image",
"--ps_launch_cmd", "python run-ps.py", "--verbose", "--quicklink",
"AAA=http://master-0:8321", "--quicklink",
"BBB=http://worker-0:1234", "--tensorboard" });
Service serviceSpec = getServiceSpecFromJobSubmitter(
runJobCli.getJobSubmitter());
Assert.assertEquals(4, serviceSpec.getComponents().size());
commonVerifyDistributedTrainingSpec(serviceSpec);
verifyQuicklink(serviceSpec, ImmutableMap
.of("AAA", "http://master-0.my-job.username.null:8321", "BBB",
"http://worker-0.my-job.username.null:1234",
YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL,
"http://tensorboard-0.my-job.username.null:6006"));
}
} }