diff --git a/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java b/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java index e7b1e2f2bee..2d91a641410 100644 --- a/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java +++ b/hadoop-submarine/hadoop-submarine-core/src/main/java/org/apache/hadoop/yarn/submarine/client/cli/param/RunJobParameters.java @@ -293,10 +293,18 @@ public String getPsDockerImage() { return psDockerImage; } + public void setPsDockerImage(String psDockerImage) { + this.psDockerImage = psDockerImage; + } + public String getWorkerDockerImage() { return workerDockerImage; } + public void setWorkerDockerImage(String workerDockerImage) { + this.workerDockerImage = workerDockerImage; + } + public boolean isDistributed() { return distributed; } @@ -313,6 +321,10 @@ public String getTensorboardDockerImage() { return tensorboardDockerImage; } + public void setTensorboardDockerImage(String tensorboardDockerImage) { + this.tensorboardDockerImage = tensorboardDockerImage; + } + public List getQuicklinks() { return quicklinks; } @@ -366,6 +378,10 @@ public RunJobParameters setConfPairs(List confPairs) { return this; } + public void setDistributed(boolean distributed) { + this.distributed = distributed; + } + @VisibleForTesting public static class UnderscoreConverterPropertyUtils extends PropertyUtils { @Override diff --git a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java index 4ad0227e09b..d092693a651 100644 --- a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java +++ 
b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/TestRunJobCliParsing.java @@ -177,6 +177,25 @@ public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception { Assert.assertTrue(success); } + @Test + public void testJobWithoutName() throws Exception { + RunJobCli runJobCli = new RunJobCli(getMockClientContext()); + String expectedErrorMessage = + "--" + CliConstants.NAME + " is absent"; + String actualMessage = ""; + try { + runJobCli.run( + new String[]{"--docker_image", "tf-docker:1.1.0", + "--num_workers", "0", "--tensorboard", "--verbose", + "--tensorboard_resources", "memory=2G,vcores=2", + "--tensorboard_docker_image", "tb_docker_image:001"}); + } catch (ParseException e) { + actualMessage = e.getMessage(); + e.printStackTrace(); + } + assertEquals(expectedErrorMessage, actualMessage); + } + @Test public void testLaunchCommandPatternReplace() throws Exception { RunJobCli runJobCli = new RunJobCli(getMockClientContext()); diff --git a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java index 43342932b22..7ef03f53571 100644 --- a/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java +++ b/hadoop-submarine/hadoop-submarine-core/src/test/java/org/apache/hadoop/yarn/submarine/common/fs/MockRemoteDirectoryManager.java @@ -26,6 +26,7 @@ import java.io.File; import java.io.IOException; +import java.util.Objects; public class MockRemoteDirectoryManager implements RemoteDirectoryManager { private File jobsParentDir = null; @@ -35,6 +36,7 @@ public class MockRemoteDirectoryManager implements RemoteDirectoryManager { @Override public Path getJobStagingArea(String jobName, boolean create) throws IOException { + 
Objects.requireNonNull(jobName, "Job name must not be null!"); if (jobsParentDir == null && create) { jobsParentDir = new File( "target/_staging_area_" + System.currentTimeMillis()); diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml index a337c42fb99..15dffb95e7e 100644 --- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/pom.xml @@ -115,6 +115,12 @@ hadoop-yarn-services-core 3.3.0-SNAPSHOT + + org.apache.hadoop + hadoop-yarn-common + test-jar + test + diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java new file mode 100644 index 00000000000..903ae090f0c --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/AbstractComponent.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; + +import java.io.IOException; +import java.util.Objects; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName; + +/** + * Abstract base class for Component classes. + * The implementations of this class are act like factories for + * {@link Component} instances. + * All dependencies are passed to the constructor so that child classes + * are obliged to provide matching constructors. + */ +public abstract class AbstractComponent { + private final FileSystemOperations fsOperations; + protected final RunJobParameters parameters; + protected final TaskType taskType; + private final RemoteDirectoryManager remoteDirectoryManager; + protected final Configuration yarnConfig; + private final LaunchCommandFactory launchCommandFactory; + + /** + * This is only required for testing. 
+ */ + private String localScriptFile; + + public AbstractComponent(FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + RunJobParameters parameters, TaskType taskType, + Configuration yarnConfig, + LaunchCommandFactory launchCommandFactory) { + this.fsOperations = fsOperations; + this.remoteDirectoryManager = remoteDirectoryManager; + this.parameters = parameters; + this.taskType = taskType; + this.launchCommandFactory = launchCommandFactory; + this.yarnConfig = yarnConfig; + } + + protected abstract Component createComponent() throws IOException; + + /** + * Generates a command launch script on local disk, + * returns path to the script. + */ + protected void generateLaunchCommand(Component component) + throws IOException { + AbstractLaunchCommand launchCommand = + launchCommandFactory.createLaunchCommand(taskType, component); + this.localScriptFile = launchCommand.generateLaunchScript(); + + String remoteLaunchCommand = uploadLaunchCommand(component); + component.setLaunchCommand(remoteLaunchCommand); + } + + private String uploadLaunchCommand(Component component) + throws IOException { + Objects.requireNonNull(localScriptFile, "localScriptFile should be " + + "set before calling this method!"); + Path stagingDir = + remoteDirectoryManager.getJobStagingArea(parameters.getName(), true); + + String destScriptFileName = getScriptFileName(taskType); + fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, + localScriptFile, destScriptFileName, component); + + return "./" + destScriptFileName; + } + + String getLocalScriptFile() { + return localScriptFile; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java new file mode 100644 index 
00000000000..edac6eda8d0 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/FileSystemOperations.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.service.api.records.ConfigFile; +import org.apache.hadoop.yarn.submarine.common.ClientContext; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.utils.ZipUtilities; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +/** + * Contains methods to perform file system operations. Almost all of the methods + * are regular non-static methods as the operations are performed with the help + * of a {@link RemoteDirectoryManager} instance passed in as a constructor + * dependency. Please note that some operations require to read config settings + * as well, so that we have Submarine and YARN config objects as dependencies as + * well. 
+ */ +public class FileSystemOperations { + private static final Logger LOG = + LoggerFactory.getLogger(FileSystemOperations.class); + private final Configuration submarineConfig; + private final Configuration yarnConfig; + + private Set uploadedFiles = new HashSet<>(); + private RemoteDirectoryManager remoteDirectoryManager; + + public FileSystemOperations(ClientContext clientContext) { + this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager(); + this.submarineConfig = clientContext.getSubmarineConfig(); + this.yarnConfig = clientContext.getYarnConfig(); + } + + /** + * May download a remote uri(file/dir) and zip. + * Skip download if local dir + * Remote uri can be a local dir(won't download) + * or remote HDFS dir, s3 dir/file .etc + * */ + public String downloadAndZip(String remoteDir, String zipFileName, + boolean doZip) + throws IOException { + //Append original modification time and size to zip file name + String suffix; + String srcDir = remoteDir; + String zipDirPath = + System.getProperty("java.io.tmpdir") + "/" + zipFileName; + boolean needDeleteTempDir = false; + if (remoteDirectoryManager.isRemote(remoteDir)) { + //Append original modification time and size to zip file name + FileStatus status = + remoteDirectoryManager.getRemoteFileStatus(new Path(remoteDir)); + suffix = "_" + status.getModificationTime() + + "-" + remoteDirectoryManager.getRemoteFileSize(remoteDir); + // Download them to temp dir + boolean downloaded = + remoteDirectoryManager.copyRemoteToLocal(remoteDir, zipDirPath); + if (!downloaded) { + throw new IOException("Failed to download files from " + + remoteDir); + } + LOG.info("Downloaded remote: {} to local: {}", remoteDir, zipDirPath); + srcDir = zipDirPath; + needDeleteTempDir = true; + } else { + File localDir = new File(remoteDir); + suffix = "_" + localDir.lastModified() + + "-" + localDir.length(); + } + if (!doZip) { + return srcDir; + } + // zip a local dir + String zipFileUri = + ZipUtilities.zipDir(srcDir, 
zipDirPath + suffix + ".zip"); + // delete downloaded temp dir + if (needDeleteTempDir) { + deleteFiles(srcDir); + } + return zipFileUri; + } + + public void deleteFiles(String localUri) { + boolean success = FileUtil.fullyDelete(new File(localUri)); + if (!success) { + LOG.warn("Failed to delete {}", localUri); + } + LOG.info("Deleted {}", localUri); + } + + @VisibleForTesting + public void uploadToRemoteFileAndLocalizeToContainerWorkDir(Path stagingDir, + String fileToUpload, String destFilename, Component comp) + throws IOException { + Path uploadedFilePath = uploadToRemoteFile(stagingDir, fileToUpload); + locateRemoteFileToContainerWorkDir(destFilename, comp, uploadedFilePath); + } + + private void locateRemoteFileToContainerWorkDir(String destFilename, + Component comp, Path uploadedFilePath) + throws IOException { + FileSystem fs = FileSystem.get(yarnConfig); + + FileStatus fileStatus = fs.getFileStatus(uploadedFilePath); + LOG.info("Uploaded file path = " + fileStatus.getPath()); + + // Set it to component's files list + comp.getConfiguration().getFiles().add(new ConfigFile().srcFile( + fileStatus.getPath().toUri().toString()).destFile(destFilename) + .type(ConfigFile.TypeEnum.STATIC)); + } + + public Path uploadToRemoteFile(Path stagingDir, String fileToUpload) throws + IOException { + FileSystem fs = remoteDirectoryManager.getDefaultFileSystem(); + + // Upload to remote FS under staging area + File localFile = new File(fileToUpload); + if (!localFile.exists()) { + throw new FileNotFoundException( + "Trying to upload file=" + localFile.getAbsolutePath() + + " to remote, but couldn't find local file."); + } + String filename = new File(fileToUpload).getName(); + + Path uploadedFilePath = new Path(stagingDir, filename); + if (!uploadedFiles.contains(uploadedFilePath)) { + if (SubmarineLogs.isVerbose()) { + LOG.info("Copying local file=" + fileToUpload + " to remote=" + + uploadedFilePath); + } + fs.copyFromLocalFile(new Path(fileToUpload), uploadedFilePath); 
+ uploadedFiles.add(uploadedFilePath); + } + return uploadedFilePath; + } + + public void validFileSize(String uri) throws IOException { + long actualSizeByte; + String locationType = "Local"; + if (remoteDirectoryManager.isRemote(uri)) { + actualSizeByte = remoteDirectoryManager.getRemoteFileSize(uri); + locationType = "Remote"; + } else { + actualSizeByte = FileUtil.getDU(new File(uri)); + } + long maxFileSizeMB = submarineConfig + .getLong(SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB, + SubmarineConfiguration.DEFAULT_MAX_ALLOWED_REMOTE_URI_SIZE_MB); + LOG.info("{} file/dir: {}, size(Byte):{}," + " Allowed max file/dir size: {}", + locationType, uri, actualSizeByte, maxFileSizeMB * 1024 * 1024); + + if (actualSizeByte > maxFileSizeMB * 1024 * 1024) { + throw new IOException(uri + " size(Byte): " + + actualSizeByte + " exceeds configured max size:" + + maxFileSizeMB * 1024 * 1024); + } + } + + public void setPermission(Path destPath, FsPermission permission) throws + IOException { + FileSystem fs = FileSystem.get(yarnConfig); + fs.setPermission(destPath, new FsPermission(permission)); + } + + public static boolean needHdfs(String content) { + return content != null && content.contains("hdfs://"); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java new file mode 100644 index 00000000000..461525f3f2c --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/HadoopEnvironmentSetup.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.ClientContext; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.PrintWriter; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs; +import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath; +import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.getValueOfEnvironment; + +/** + * This class contains helper methods to fill HDFS and Java environment + * variables into scripts. 
+ */ +public class HadoopEnvironmentSetup { + private static final Logger LOG = + LoggerFactory.getLogger(HadoopEnvironmentSetup.class); + private static final String CORE_SITE_XML = "core-site.xml"; + private static final String HDFS_SITE_XML = "hdfs-site.xml"; + + public static final String DOCKER_HADOOP_HDFS_HOME = + "DOCKER_HADOOP_HDFS_HOME"; + public static final String DOCKER_JAVA_HOME = "DOCKER_JAVA_HOME"; + private final RemoteDirectoryManager remoteDirectoryManager; + private final FileSystemOperations fsOperations; + + public HadoopEnvironmentSetup(ClientContext clientContext, + FileSystemOperations fsOperations) { + this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager(); + this.fsOperations = fsOperations; + } + + public void addHdfsClassPath(RunJobParameters parameters, + PrintWriter fw, Component comp) throws IOException { + // Find envs to use HDFS + String hdfsHome = null; + String javaHome = null; + + boolean hadoopEnv = false; + + for (String envVar : parameters.getEnvars()) { + if (envVar.startsWith(DOCKER_HADOOP_HDFS_HOME + "=")) { + hdfsHome = getValueOfEnvironment(envVar); + hadoopEnv = true; + } else if (envVar.startsWith(DOCKER_JAVA_HOME + "=")) { + javaHome = getValueOfEnvironment(envVar); + } + } + + boolean hasHdfsEnvs = hdfsHome != null && javaHome != null; + boolean needHdfs = doesNeedHdfs(parameters, hadoopEnv); + if (needHdfs) { + // HDFS is asked either in input or output, set LD_LIBRARY_PATH + // and classpath + if (hdfsHome != null) { + appendHdfsHome(fw, hdfsHome); + } + + // hadoop confs will be uploaded to HDFS and localized to container's + // local folder, so here set $HADOOP_CONF_DIR to $WORK_DIR. 
+ fw.append("export HADOOP_CONF_DIR=$WORK_DIR\n"); + if (javaHome != null) { + appendJavaHome(fw, javaHome); + } + + fw.append( + "export CLASSPATH=`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`\n"); + } + + if (needHdfs && !hasHdfsEnvs) { + LOG.error("When HDFS is being used to read/write models/data, " + + "the following environment variables are required: " + + "1) {}= " + + "2) {}=. " + + "You can use --env to pass these environment variables.", + DOCKER_HADOOP_HDFS_HOME, DOCKER_JAVA_HOME); + throw new IOException("Failed to detect HDFS-related environments."); + } + + // Trying to upload core-site.xml and hdfs-site.xml + Path stagingDir = + remoteDirectoryManager.getJobStagingArea( + parameters.getName(), true); + File coreSite = findFileOnClassPath(CORE_SITE_XML); + File hdfsSite = findFileOnClassPath(HDFS_SITE_XML); + if (coreSite == null || hdfsSite == null) { + LOG.error("HDFS is being used, however we could not locate " + + "{} nor {} on classpath! " + + "Please double check your classpath setting and make sure these " + + "setting files are included!", CORE_SITE_XML, HDFS_SITE_XML); + throw new IOException( + "Failed to locate core-site.xml / hdfs-site.xml on classpath!"); + } + fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, + coreSite.getAbsolutePath(), CORE_SITE_XML, comp); + fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, + hdfsSite.getAbsolutePath(), HDFS_SITE_XML, comp); + + // DEBUG + if (SubmarineLogs.isVerbose()) { + appendEchoOfEnvVars(fw); + } + } + + private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) { + return needHdfs(parameters.getInputPath()) || + needHdfs(parameters.getPSLaunchCmd()) || + needHdfs(parameters.getWorkerLaunchCmd()) || + hadoopEnv; + } + + private void appendHdfsHome(PrintWriter fw, String hdfsHome) { + // Unset HADOOP_HOME/HADOOP_YARN_HOME to make sure host machine's envs + // won't pollute docker's env. 
+ fw.append("export HADOOP_HOME=\n"); + fw.append("export HADOOP_YARN_HOME=\n"); + fw.append("export HADOOP_HDFS_HOME=" + hdfsHome + "\n"); + fw.append("export HADOOP_COMMON_HOME=" + hdfsHome + "\n"); + } + + private void appendJavaHome(PrintWriter fw, String javaHome) { + fw.append("export JAVA_HOME=" + javaHome + "\n"); + fw.append("export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:" + + "$JAVA_HOME/lib/amd64/server\n"); + } + + private void appendEchoOfEnvVars(PrintWriter fw) { + fw.append("echo \"CLASSPATH:$CLASSPATH\"\n"); + fw.append("echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n"); + fw.append( + "echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n"); + fw.append("echo \"JAVA_HOME:$JAVA_HOME\"\n"); + fw.append("echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n"); + fw.append("echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n"); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpec.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpec.java new file mode 100644 index 00000000000..f26d61071c6 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpec.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import java.io.IOException; + +/** + * This interface is to provide means of creating wrappers around + * {@link org.apache.hadoop.yarn.service.api.records.Service} instances. + */ +public interface ServiceSpec { + ServiceWrapper create() throws IOException; +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java new file mode 100644 index 00000000000..06e36d58281 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceSpecFileGenerator.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import org.apache.hadoop.yarn.service.api.records.Service; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.io.Writer; +import java.nio.charset.StandardCharsets; + +import static org.apache.hadoop.yarn.service.utils.ServiceApiUtil.jsonSerDeser; + +/** + * This class is merely responsible for creating Json representation of + * {@link Service} instances. + */ +public final class ServiceSpecFileGenerator { + private ServiceSpecFileGenerator() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + static String generateJson(Service service) throws IOException { + File serviceSpecFile = File.createTempFile(service.getName(), ".json"); + String buffer = jsonSerDeser.toJson(service); + Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile), + StandardCharsets.UTF_8); + try (PrintWriter pw = new PrintWriter(w)) { + pw.append(buffer); + } + return serviceSpecFile.getAbsolutePath(); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceWrapper.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceWrapper.java new file mode 100644 index 00000000000..3891602c023 --- /dev/null +++ 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/ServiceWrapper.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Maps; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.service.api.records.Service; + +import java.io.IOException; +import java.util.Map; + +/** + * This class is only existing because we need a component name to + * local launch command mapping from the test code. + * Once this is solved in more clean or different way, we can delete this class. 
+ */ +public class ServiceWrapper { + private final Service service; + + @VisibleForTesting + private Map componentToLocalLaunchCommand = Maps.newHashMap(); + + public ServiceWrapper(Service service) { + this.service = service; + } + + public void addComponent(AbstractComponent abstractComponent) + throws IOException { + Component component = abstractComponent.createComponent(); + service.addComponent(component); + storeComponentName(abstractComponent, component.getName()); + } + + private void storeComponentName( + AbstractComponent component, String name) { + componentToLocalLaunchCommand.put(name, + component.getLocalScriptFile()); + } + + public Service getService() { + return service; + } + + public String getLocalLaunchCommandPathForComponent(String componentName) { + return componentToLocalLaunchCommand.get(componentName); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java index 58a33cf321d..37445a6bfc7 100644 --- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceJobSubmitter.java @@ -15,858 +15,59 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; import com.google.common.annotations.VisibleForTesting; -import org.apache.commons.lang3.StringUtils; -import org.apache.hadoop.fs.FileStatus; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.FileUtil; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.yarn.api.records.ApplicationId; import 
org.apache.hadoop.yarn.client.api.AppAdminClient; import org.apache.hadoop.yarn.exceptions.YarnException; -import org.apache.hadoop.yarn.service.api.ServiceApiConstants; -import org.apache.hadoop.yarn.service.api.records.Artifact; -import org.apache.hadoop.yarn.service.api.records.Component; -import org.apache.hadoop.yarn.service.api.records.ConfigFile; -import org.apache.hadoop.yarn.service.api.records.Resource; -import org.apache.hadoop.yarn.service.api.records.ResourceInformation; import org.apache.hadoop.yarn.service.api.records.Service; -import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal; import org.apache.hadoop.yarn.service.utils.ServiceApiUtil; -import org.apache.hadoop.yarn.submarine.client.cli.param.Localization; -import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink; import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.common.ClientContext; -import org.apache.hadoop.yarn.submarine.common.Envs; -import org.apache.hadoop.yarn.submarine.common.api.TaskType; -import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration; -import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; -import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec; +import org.apache.hadoop.yarn.submarine.utils.Localizer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.FileOutputStream; import java.io.IOException; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.io.Writer; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import 
java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.StringTokenizer; -import java.util.zip.ZipEntry; -import java.util.zip.ZipOutputStream; -import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION; - -import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants - .CONTAINER_STATE_REPORT_AS_SERVICE_STATE; import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS; -import static org.apache.hadoop.yarn.service.utils.ServiceApiUtil.jsonSerDeser; /** - * Submit a job to cluster + * Submit a job to cluster. */ public class YarnServiceJobSubmitter implements JobSubmitter { - public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard"; + private static final Logger LOG = LoggerFactory.getLogger(YarnServiceJobSubmitter.class); - ClientContext clientContext; - Service serviceSpec; - private Set uploadedFiles = new HashSet<>(); + private ClientContext clientContext; + private ServiceWrapper serviceWrapper; - // Used by testing - private Map componentToLocalLaunchScriptPath = - new HashMap<>(); - - public YarnServiceJobSubmitter(ClientContext clientContext) { + YarnServiceJobSubmitter(ClientContext clientContext) { this.clientContext = clientContext; } - private Resource getServiceResourceFromYarnResource( - org.apache.hadoop.yarn.api.records.Resource yarnResource) { - Resource serviceResource = new Resource(); - serviceResource.setCpus(yarnResource.getVirtualCores()); - serviceResource.setMemory(String.valueOf(yarnResource.getMemorySize())); - - Map riMap = new HashMap<>(); - for (org.apache.hadoop.yarn.api.records.ResourceInformation ri : yarnResource - .getAllResourcesListCopy()) { - ResourceInformation serviceRi = - new ResourceInformation(); - serviceRi.setValue(ri.getValue()); - serviceRi.setUnit(ri.getUnits()); - riMap.put(ri.getName(), serviceRi); - } - serviceResource.setResourceInformations(riMap); - - return 
serviceResource; - } - - private String getValueOfEnvironment(String envar) { - // extract value from "key=value" form - if (envar == null || !envar.contains("=")) { - return ""; - } else { - return envar.substring(envar.indexOf("=") + 1); - } - } - - private boolean needHdfs(String content) { - return content != null && content.contains("hdfs://"); - } - - private void addHdfsClassPathIfNeeded(RunJobParameters parameters, - PrintWriter fw, Component comp) throws IOException { - // Find envs to use HDFS - String hdfsHome = null; - String javaHome = null; - - boolean hadoopEnv = false; - - for (String envar : parameters.getEnvars()) { - if (envar.startsWith("DOCKER_HADOOP_HDFS_HOME=")) { - hdfsHome = getValueOfEnvironment(envar); - hadoopEnv = true; - } else if (envar.startsWith("DOCKER_JAVA_HOME=")) { - javaHome = getValueOfEnvironment(envar); - } - } - - boolean lackingEnvs = false; - - if (needHdfs(parameters.getInputPath()) || needHdfs( - parameters.getPSLaunchCmd()) || needHdfs( - parameters.getWorkerLaunchCmd()) || hadoopEnv) { - // HDFS is asked either in input or output, set LD_LIBRARY_PATH - // and classpath - if (hdfsHome != null) { - // Unset HADOOP_HOME/HADOOP_YARN_HOME to make sure host machine's envs - // won't pollute docker's env. - fw.append("export HADOOP_HOME=\n"); - fw.append("export HADOOP_YARN_HOME=\n"); - fw.append("export HADOOP_HDFS_HOME=" + hdfsHome + "\n"); - fw.append("export HADOOP_COMMON_HOME=" + hdfsHome + "\n"); - } else{ - lackingEnvs = true; - } - - // hadoop confs will be uploaded to HDFS and localized to container's - // local folder, so here set $HADOOP_CONF_DIR to $WORK_DIR. 
- fw.append("export HADOOP_CONF_DIR=$WORK_DIR\n"); - if (javaHome != null) { - fw.append("export JAVA_HOME=" + javaHome + "\n"); - fw.append("export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:" - + "$JAVA_HOME/lib/amd64/server\n"); - } else { - lackingEnvs = true; - } - fw.append("export CLASSPATH=`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`\n"); - } - - if (lackingEnvs) { - LOG.error("When hdfs is being used to read/write models/data. Following" - + "envs are required: 1) DOCKER_HADOOP_HDFS_HOME= 2) DOCKER_JAVA_HOME=. You can use --env to pass these envars."); - throw new IOException("Failed to detect HDFS-related environments."); - } - - // Trying to upload core-site.xml and hdfs-site.xml - Path stagingDir = - clientContext.getRemoteDirectoryManager().getJobStagingArea( - parameters.getName(), true); - File coreSite = findFileOnClassPath("core-site.xml"); - File hdfsSite = findFileOnClassPath("hdfs-site.xml"); - if (coreSite == null || hdfsSite == null) { - LOG.error("hdfs is being used, however we couldn't locate core-site.xml/" - + "hdfs-site.xml from classpath, please double check you classpath" - + "setting and make sure they're included."); - throw new IOException( - "Failed to locate core-site.xml / hdfs-site.xml from class path"); - } - uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, - coreSite.getAbsolutePath(), "core-site.xml", comp); - uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, - hdfsSite.getAbsolutePath(), "hdfs-site.xml", comp); - - // DEBUG - if (SubmarineLogs.isVerbose()) { - fw.append("echo \"CLASSPATH:$CLASSPATH\"\n"); - fw.append("echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n"); - fw.append("echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n"); - fw.append("echo \"JAVA_HOME:$JAVA_HOME\"\n"); - fw.append("echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n"); - fw.append("echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n"); - } - } - - private void addCommonEnvironments(Component component, TaskType taskType) { - Map envs = 
component.getConfiguration().getEnv(); - envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID); - envs.put(Envs.TASK_TYPE_ENV, taskType.name()); - } - - @VisibleForTesting - protected String getUserName() { - return System.getProperty("user.name"); - } - - private String getDNSDomain() { - return clientContext.getYarnConfig().get("hadoop.registry.dns.domain-name"); - } - - /* - * Generate a command launch script on local disk, returns patch to the script - */ - private String generateCommandLaunchScript(RunJobParameters parameters, - TaskType taskType, Component comp) throws IOException { - File file = File.createTempFile(taskType.name() + "-launch-script", ".sh"); - Writer w = new OutputStreamWriter(new FileOutputStream(file), - StandardCharsets.UTF_8); - PrintWriter pw = new PrintWriter(w); - - try { - pw.append("#!/bin/bash\n"); - - addHdfsClassPathIfNeeded(parameters, pw, comp); - - if (taskType.equals(TaskType.TENSORBOARD)) { - String tbCommand = - "export LC_ALL=C && tensorboard --logdir=" + parameters - .getCheckpointPath(); - pw.append(tbCommand + "\n"); - LOG.info("Tensorboard command=" + tbCommand); - } else{ - // When distributed training is required - if (parameters.isDistributed()) { - // Generated TF_CONFIG - String tfConfigEnv = YarnServiceUtils.getTFConfigEnv( - taskType.getComponentName(), parameters.getNumWorkers(), - parameters.getNumPS(), parameters.getName(), getUserName(), - getDNSDomain()); - pw.append("export TF_CONFIG=\"" + tfConfigEnv + "\"\n"); - } - - // Print launch command - if (taskType.equals(TaskType.WORKER) || taskType.equals( - TaskType.PRIMARY_WORKER)) { - pw.append(parameters.getWorkerLaunchCmd() + '\n'); - - if (SubmarineLogs.isVerbose()) { - LOG.info( - "Worker command =[" + parameters.getWorkerLaunchCmd() + "]"); - } - } else if (taskType.equals(TaskType.PS)) { - pw.append(parameters.getPSLaunchCmd() + '\n'); - - if (SubmarineLogs.isVerbose()) { - LOG.info("PS command =[" + parameters.getPSLaunchCmd() + "]"); - } - 
} - } - } finally { - pw.close(); - } - return file.getAbsolutePath(); - } - - private String getScriptFileName(TaskType taskType) { - return "run-" + taskType.name() + ".sh"; - } - - private File findFileOnClassPath(final String fileName) { - final String classpath = System.getProperty("java.class.path"); - final String pathSeparator = System.getProperty("path.separator"); - final StringTokenizer tokenizer = new StringTokenizer(classpath, - pathSeparator); - - while (tokenizer.hasMoreTokens()) { - final String pathElement = tokenizer.nextToken(); - final File directoryOrJar = new File(pathElement); - final File absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile(); - if (absoluteDirectoryOrJar.isFile()) { - final File target = new File(absoluteDirectoryOrJar.getParent(), - fileName); - if (target.exists()) { - return target; - } - } else{ - final File target = new File(directoryOrJar, fileName); - if (target.exists()) { - return target; - } - } - } - - return null; - } - - private void uploadToRemoteFileAndLocalizeToContainerWorkDir(Path stagingDir, - String fileToUpload, String destFilename, Component comp) - throws IOException { - Path uploadedFilePath = uploadToRemoteFile(stagingDir, fileToUpload); - locateRemoteFileToContainerWorkDir(destFilename, comp, uploadedFilePath); - } - - private void locateRemoteFileToContainerWorkDir(String destFilename, - Component comp, Path uploadedFilePath) - throws IOException { - FileSystem fs = FileSystem.get(clientContext.getYarnConfig()); - - FileStatus fileStatus = fs.getFileStatus(uploadedFilePath); - LOG.info("Uploaded file path = " + fileStatus.getPath()); - - // Set it to component's files list - comp.getConfiguration().getFiles().add(new ConfigFile().srcFile( - fileStatus.getPath().toUri().toString()).destFile(destFilename) - .type(ConfigFile.TypeEnum.STATIC)); - } - - private Path uploadToRemoteFile(Path stagingDir, String fileToUpload) throws - IOException { - FileSystem fs = 
clientContext.getRemoteDirectoryManager() - .getDefaultFileSystem(); - - // Upload to remote FS under staging area - File localFile = new File(fileToUpload); - if (!localFile.exists()) { - throw new FileNotFoundException( - "Trying to upload file=" + localFile.getAbsolutePath() - + " to remote, but couldn't find local file."); - } - String filename = new File(fileToUpload).getName(); - - Path uploadedFilePath = new Path(stagingDir, filename); - if (!uploadedFiles.contains(uploadedFilePath)) { - if (SubmarineLogs.isVerbose()) { - LOG.info("Copying local file=" + fileToUpload + " to remote=" - + uploadedFilePath); - } - fs.copyFromLocalFile(new Path(fileToUpload), uploadedFilePath); - uploadedFiles.add(uploadedFilePath); - } - return uploadedFilePath; - } - - private void setPermission(Path destPath, FsPermission permission) throws - IOException { - FileSystem fs = FileSystem.get(clientContext.getYarnConfig()); - fs.setPermission(destPath, new FsPermission(permission)); - } - - private void handleLaunchCommand(RunJobParameters parameters, - TaskType taskType, Component component) throws IOException { - // Get staging area directory - Path stagingDir = - clientContext.getRemoteDirectoryManager().getJobStagingArea( - parameters.getName(), true); - - // Generate script file in the local disk - String localScriptFile = generateCommandLaunchScript(parameters, taskType, - component); - String destScriptFileName = getScriptFileName(taskType); - uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, localScriptFile, - destScriptFileName, component); - - component.setLaunchCommand("./" + destScriptFileName); - componentToLocalLaunchScriptPath.put(taskType.getComponentName(), - localScriptFile); - } - - private String getLastNameFromPath(String srcFileStr) { - return new Path(srcFileStr).getName(); - } - - /** - * May download a remote uri(file/dir) and zip. 
- * Skip download if local dir - * Remote uri can be a local dir(won't download) - * or remote HDFS dir, s3 dir/file .etc - * */ - private String mayDownloadAndZipIt(String remoteDir, String zipFileName, - boolean doZip) - throws IOException { - RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager(); - //Append original modification time and size to zip file name - String suffix; - String srcDir = remoteDir; - String zipDirPath = - System.getProperty("java.io.tmpdir") + "/" + zipFileName; - boolean needDeleteTempDir = false; - if (rdm.isRemote(remoteDir)) { - //Append original modification time and size to zip file name - FileStatus status = rdm.getRemoteFileStatus(new Path(remoteDir)); - suffix = "_" + status.getModificationTime() - + "-" + rdm.getRemoteFileSize(remoteDir); - // Download them to temp dir - boolean downloaded = rdm.copyRemoteToLocal(remoteDir, zipDirPath); - if (!downloaded) { - throw new IOException("Failed to download files from " - + remoteDir); - } - LOG.info("Downloaded remote: {} to local: {}", remoteDir, zipDirPath); - srcDir = zipDirPath; - needDeleteTempDir = true; - } else { - File localDir = new File(remoteDir); - suffix = "_" + localDir.lastModified() - + "-" + localDir.length(); - } - if (!doZip) { - return srcDir; - } - // zip a local dir - String zipFileUri = zipDir(srcDir, zipDirPath + suffix + ".zip"); - // delete downloaded temp dir - if (needDeleteTempDir) { - deleteFiles(srcDir); - } - return zipFileUri; - } - - @VisibleForTesting - public String zipDir(String srcDir, String dstFile) throws IOException { - FileOutputStream fos = new FileOutputStream(dstFile); - ZipOutputStream zos = new ZipOutputStream(fos); - File srcFile = new File(srcDir); - LOG.info("Compressing {}", srcDir); - addDirToZip(zos, srcFile, srcFile); - // close the ZipOutputStream - zos.close(); - LOG.info("Compressed {} to {}", srcDir, dstFile); - return dstFile; - } - - private void deleteFiles(String localUri) { - boolean success = 
FileUtil.fullyDelete(new File(localUri)); - if (!success) { - LOG.warn("Fail to delete {}", localUri); - } - LOG.info("Deleted {}", localUri); - } - - private void addDirToZip(ZipOutputStream zos, File srcFile, File base) - throws IOException { - File[] files = srcFile.listFiles(); - if (null == files) { - return; - } - FileInputStream fis = null; - for (int i = 0; i < files.length; i++) { - // if it's directory, add recursively - if (files[i].isDirectory()) { - addDirToZip(zos, files[i], base); - continue; - } - byte[] buffer = new byte[1024]; - try { - fis = new FileInputStream(files[i]); - String name = base.toURI().relativize(files[i].toURI()).getPath(); - LOG.info(" Zip adding: " + name); - zos.putNextEntry(new ZipEntry(name)); - int length; - while ((length = fis.read(buffer)) > 0) { - zos.write(buffer, 0, length); - } - zos.flush(); - } finally { - if (fis != null) { - fis.close(); - } - zos.closeEntry(); - } - } - } - - private void addWorkerComponent(Service service, - RunJobParameters parameters, TaskType taskType) throws IOException { - Component workerComponent = new Component(); - addCommonEnvironments(workerComponent, taskType); - - workerComponent.setName(taskType.getComponentName()); - - if (taskType.equals(TaskType.PRIMARY_WORKER)) { - workerComponent.setNumberOfContainers(1L); - workerComponent.getConfiguration().setProperty( - CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true"); - } else{ - workerComponent.setNumberOfContainers( - (long) parameters.getNumWorkers() - 1); - } - - if (parameters.getWorkerDockerImage() != null) { - workerComponent.setArtifact( - getDockerArtifact(parameters.getWorkerDockerImage())); - } - - workerComponent.setResource( - getServiceResourceFromYarnResource(parameters.getWorkerResource())); - handleLaunchCommand(parameters, taskType, workerComponent); - workerComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER); - service.addComponent(workerComponent); - } - - // Handle worker and primary_worker. 
- private void addWorkerComponents(Service service, RunJobParameters parameters) - throws IOException { - addWorkerComponent(service, parameters, TaskType.PRIMARY_WORKER); - - if (parameters.getNumWorkers() > 1) { - addWorkerComponent(service, parameters, TaskType.WORKER); - } - } - - private void appendToEnv(Service service, String key, String value, - String delim) { - Map env = service.getConfiguration().getEnv(); - if (!env.containsKey(key)) { - env.put(key, value); - } else { - if (!value.isEmpty()) { - String existingValue = env.get(key); - if (!existingValue.endsWith(delim)) { - env.put(key, existingValue + delim + value); - } else { - env.put(key, existingValue + value); - } - } - } - } - - private void handleServiceEnvs(Service service, RunJobParameters parameters) { - if (parameters.getEnvars() != null) { - for (String envarPair : parameters.getEnvars()) { - String key, value; - if (envarPair.contains("=")) { - int idx = envarPair.indexOf('='); - key = envarPair.substring(0, idx); - value = envarPair.substring(idx + 1); - } else{ - // No "=" found so use the whole key - key = envarPair; - value = ""; - } - appendToEnv(service, key, value, ":"); - } - } - - // Append other configs like /etc/passwd, /etc/krb5.conf - appendToEnv(service, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS", - "/etc/passwd:/etc/passwd:ro", ","); - - String authenication = clientContext.getYarnConfig().get( - HADOOP_SECURITY_AUTHENTICATION); - if (authenication != null && authenication.equals("kerberos")) { - appendToEnv(service, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS", - "/etc/krb5.conf:/etc/krb5.conf:ro", ","); - } - } - - private Artifact getDockerArtifact(String dockerImageName) { - return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName); - } - - private void handleQuicklinks(RunJobParameters runJobParameters) - throws IOException { - List quicklinks = runJobParameters.getQuicklinks(); - if (null != quicklinks && !quicklinks.isEmpty()) { - for (Quicklink ql : quicklinks) 
{ - // Make sure it is a valid instance name - String instanceName = ql.getComponentInstanceName(); - boolean found = false; - - for (Component comp : serviceSpec.getComponents()) { - for (int i = 0; i < comp.getNumberOfContainers(); i++) { - String possibleInstanceName = comp.getName() + "-" + i; - if (possibleInstanceName.equals(instanceName)) { - found = true; - break; - } - } - } - - if (!found) { - throw new IOException( - "Couldn't find a component instance = " + instanceName - + " while adding quicklink"); - } - - String link = ql.getProtocol() + YarnServiceUtils.getDNSName( - serviceSpec.getName(), instanceName, getUserName(), getDNSDomain(), - ql.getPort()); - YarnServiceUtils.addQuicklink(serviceSpec, ql.getLabel(), link); - } - } - } - - private Service createServiceByParameters(RunJobParameters parameters) - throws IOException { - componentToLocalLaunchScriptPath.clear(); - serviceSpec = new Service(); - serviceSpec.setName(parameters.getName()); - serviceSpec.setVersion(String.valueOf(System.currentTimeMillis())); - serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName())); - handleKerberosPrincipal(parameters); - - handleServiceEnvs(serviceSpec, parameters); - - handleLocalizations(parameters); - - if (parameters.getNumWorkers() > 0) { - addWorkerComponents(serviceSpec, parameters); - } - - if (parameters.getNumPS() > 0) { - Component psComponent = new Component(); - psComponent.setName(TaskType.PS.getComponentName()); - addCommonEnvironments(psComponent, TaskType.PS); - psComponent.setNumberOfContainers((long) parameters.getNumPS()); - psComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER); - psComponent.setResource( - getServiceResourceFromYarnResource(parameters.getPsResource())); - - // Override global docker image if needed. 
- if (parameters.getPsDockerImage() != null) { - psComponent.setArtifact( - getDockerArtifact(parameters.getPsDockerImage())); - } - handleLaunchCommand(parameters, TaskType.PS, psComponent); - serviceSpec.addComponent(psComponent); - } - - if (parameters.isTensorboardEnabled()) { - Component tbComponent = new Component(); - tbComponent.setName(TaskType.TENSORBOARD.getComponentName()); - addCommonEnvironments(tbComponent, TaskType.TENSORBOARD); - tbComponent.setNumberOfContainers(1L); - tbComponent.setRestartPolicy(Component.RestartPolicyEnum.NEVER); - tbComponent.setResource(getServiceResourceFromYarnResource( - parameters.getTensorboardResource())); - if (parameters.getTensorboardDockerImage() != null) { - tbComponent.setArtifact( - getDockerArtifact(parameters.getTensorboardDockerImage())); - } - - handleLaunchCommand(parameters, TaskType.TENSORBOARD, tbComponent); - - // Add tensorboard to quicklink - String tensorboardLink = "http://" + YarnServiceUtils.getDNSName( - parameters.getName(), - TaskType.TENSORBOARD.getComponentName() + "-" + 0, getUserName(), - getDNSDomain(), 6006); - LOG.info("Link to tensorboard:" + tensorboardLink); - serviceSpec.addComponent(tbComponent); - - YarnServiceUtils.addQuicklink(serviceSpec, TENSORBOARD_QUICKLINK_LABEL, - tensorboardLink); - } - - // After all components added, handle quicklinks - handleQuicklinks(parameters); - - return serviceSpec; - } - - /** - * Localize dependencies for all containers. - * If remoteUri is a local directory, - * we'll zip it, upload to HDFS staging dir HDFS. - * If remoteUri is directory, we'll download it, zip it and upload - * to HDFS. 
- * If localFilePath is ".", we'll use remoteUri's file/dir name - * */ - private void handleLocalizations(RunJobParameters parameters) - throws IOException { - // Handle localizations - Path stagingDir = - clientContext.getRemoteDirectoryManager().getJobStagingArea( - parameters.getName(), true); - List locs = parameters.getLocalizations(); - String remoteUri; - String containerLocalPath; - RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager(); - - // Check to fail fast - for (Localization loc : locs) { - remoteUri = loc.getRemoteUri(); - Path resourceToLocalize = new Path(remoteUri); - // Check if remoteUri exists - if (rdm.isRemote(remoteUri)) { - // check if exists - if (!rdm.existsRemoteFile(resourceToLocalize)) { - throw new FileNotFoundException( - "File " + remoteUri + " doesn't exists."); - } - } else { - // Check if exists - File localFile = new File(remoteUri); - if (!localFile.exists()) { - throw new FileNotFoundException( - "File " + remoteUri + " doesn't exists."); - } - } - // check remote file size - validFileSize(remoteUri); - } - // Start download remote if needed and upload to HDFS - for (Localization loc : locs) { - remoteUri = loc.getRemoteUri(); - containerLocalPath = loc.getLocalPath(); - String srcFileStr = remoteUri; - ConfigFile.TypeEnum destFileType = ConfigFile.TypeEnum.STATIC; - Path resourceToLocalize = new Path(remoteUri); - boolean needUploadToHDFS = true; - - /** - * Special handling for remoteUri directory. - * */ - boolean needDeleteTempFile = false; - if (rdm.isDir(remoteUri)) { - destFileType = ConfigFile.TypeEnum.ARCHIVE; - srcFileStr = mayDownloadAndZipIt( - remoteUri, getLastNameFromPath(srcFileStr), true); - } else if (rdm.isRemote(remoteUri)) { - if (!needHdfs(remoteUri)) { - // Non HDFS remote uri. 
Non directory, no need to zip - srcFileStr = mayDownloadAndZipIt( - remoteUri, getLastNameFromPath(srcFileStr), false); - needDeleteTempFile = true; - } else { - // HDFS file, no need to upload - needUploadToHDFS = false; - } - } - - // Upload file to HDFS - if (needUploadToHDFS) { - resourceToLocalize = uploadToRemoteFile(stagingDir, srcFileStr); - } - if (needDeleteTempFile) { - deleteFiles(srcFileStr); - } - // Remove .zip from zipped dir name - if (destFileType == ConfigFile.TypeEnum.ARCHIVE - && srcFileStr.endsWith(".zip")) { - // Delete local zip file - deleteFiles(srcFileStr); - int suffixIndex = srcFileStr.lastIndexOf('_'); - srcFileStr = srcFileStr.substring(0, suffixIndex); - } - // If provided, use the name of local uri - if (!containerLocalPath.equals(".") - && !containerLocalPath.equals("./")) { - // Change the YARN localized file name to what'll used in container - srcFileStr = getLastNameFromPath(containerLocalPath); - } - String localizedName = getLastNameFromPath(srcFileStr); - LOG.info("The file/dir to be localized is {}", - resourceToLocalize.toString()); - LOG.info("Its localized file name will be {}", localizedName); - serviceSpec.getConfiguration().getFiles().add(new ConfigFile().srcFile( - resourceToLocalize.toUri().toString()).destFile(localizedName) - .type(destFileType)); - // set mounts - // if mount path is absolute, just use it. 
- // if relative, no need to mount explicitly - if (containerLocalPath.startsWith("/")) { - String mountStr = getLastNameFromPath(srcFileStr) + ":" - + containerLocalPath + ":" + loc.getMountPermission(); - LOG.info("Add bind-mount string {}", mountStr); - appendToEnv(serviceSpec, "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS", - mountStr, ","); - } - } - } - - private void validFileSize(String uri) throws IOException { - RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager(); - long actualSizeByte; - String locationType = "Local"; - if (rdm.isRemote(uri)) { - actualSizeByte = clientContext.getRemoteDirectoryManager() - .getRemoteFileSize(uri); - locationType = "Remote"; - } else { - actualSizeByte = FileUtil.getDU(new File(uri)); - } - long maxFileSizeMB = clientContext.getSubmarineConfig() - .getLong(SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB, - SubmarineConfiguration.DEFAULT_MAX_ALLOWED_REMOTE_URI_SIZE_MB); - LOG.info("{} fie/dir: {}, size(Byte):{}," - + " Allowed max file/dir size: {}", - locationType, uri, actualSizeByte, maxFileSizeMB * 1024 * 1024); - - if (actualSizeByte > maxFileSizeMB * 1024 * 1024) { - throw new IOException(uri + " size(Byte): " - + actualSizeByte + " exceeds configured max size:" - + maxFileSizeMB * 1024 * 1024); - } - } - - private String generateServiceSpecFile(Service service) throws IOException { - File serviceSpecFile = File.createTempFile(service.getName(), ".json"); - String buffer = jsonSerDeser.toJson(service); - Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile), - "UTF-8"); - PrintWriter pw = new PrintWriter(w); - try { - pw.append(buffer); - } finally { - pw.close(); - } - return serviceSpecFile.getAbsolutePath(); - } - - private void handleKerberosPrincipal(RunJobParameters parameters) throws - IOException { - if(StringUtils.isNotBlank(parameters.getKeytab()) && StringUtils - .isNotBlank(parameters.getPrincipal())) { - String keytab = parameters.getKeytab(); - String principal 
= parameters.getPrincipal(); - if(parameters.isDistributeKeytab()) { - Path stagingDir = - clientContext.getRemoteDirectoryManager().getJobStagingArea( - parameters.getName(), true); - Path remoteKeytabPath = uploadToRemoteFile(stagingDir, keytab); - //only the owner has read access - setPermission(remoteKeytabPath, - FsPermission.createImmutable((short)Integer.parseInt("400", 8))); - serviceSpec.setKerberosPrincipal(new KerberosPrincipal().keytab( - remoteKeytabPath.toString()).principalName(principal)); - } else { - if(!keytab.startsWith("file")) { - keytab = "file://" + keytab; - } - serviceSpec.setKerberosPrincipal(new KerberosPrincipal().keytab( - keytab).principalName(principal)); - } - } - } - /** * {@inheritDoc} */ @Override public ApplicationId submitJob(RunJobParameters parameters) throws IOException, YarnException { - createServiceByParameters(parameters); - String serviceSpecFile = generateServiceSpecFile(serviceSpec); + FileSystemOperations fsOperations = new FileSystemOperations(clientContext); + HadoopEnvironmentSetup hadoopEnvSetup = + new HadoopEnvironmentSetup(clientContext, fsOperations); - AppAdminClient appAdminClient = YarnServiceUtils.createServiceClient( - clientContext.getYarnConfig()); + Service serviceSpec = createTensorFlowServiceSpec(parameters, + fsOperations, hadoopEnvSetup); + String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec); + + AppAdminClient appAdminClient = + YarnServiceUtils.createServiceClient(clientContext.getYarnConfig()); int code = appAdminClient.actionLaunch(serviceSpecFile, serviceSpec.getName(), null, null); - if(code != EXIT_SUCCESS) { - throw new YarnException("Fail to launch application with exit code:" + - code); + if (code != EXIT_SUCCESS) { + throw new YarnException( + "Fail to launch application with exit code:" + code); } String appStatus=appAdminClient.getStatusString(serviceSpec.getName()); @@ -896,13 +97,24 @@ public ApplicationId submitJob(RunJobParameters parameters) return appid; 
} - @VisibleForTesting - public Service getServiceSpec() { - return serviceSpec; + private Service createTensorFlowServiceSpec(RunJobParameters parameters, + FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup) + throws IOException { + LaunchCommandFactory launchCommandFactory = + new LaunchCommandFactory(hadoopEnvSetup, parameters, + clientContext.getYarnConfig()); + Localizer localizer = new Localizer(fsOperations, + clientContext.getRemoteDirectoryManager(), parameters); + TensorFlowServiceSpec tensorFlowServiceSpec = new TensorFlowServiceSpec( + parameters, this.clientContext, fsOperations, launchCommandFactory, + localizer); + + serviceWrapper = tensorFlowServiceSpec.create(); + return serviceWrapper.getService(); } @VisibleForTesting - public Map getComponentToLocalLaunchScriptPath() { - return componentToLocalLaunchScriptPath; + public ServiceWrapper getServiceWrapper() { + return serviceWrapper; } } diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java index c599fc9591b..352fd79dedc 100644 --- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/YarnServiceUtils.java @@ -17,33 +17,27 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.client.api.AppAdminClient; -import org.apache.hadoop.yarn.service.api.records.Service; -import org.apache.hadoop.yarn.submarine.common.Envs; -import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; -import org.slf4j.Logger; -import 
org.slf4j.LoggerFactory; - -import java.util.HashMap; -import java.util.Map; import static org.apache.hadoop.yarn.client.api.AppAdminClient.DEFAULT_TYPE; -public class YarnServiceUtils { - private static final Logger LOG = - LoggerFactory.getLogger(YarnServiceUtils.class); +/** + * This class contains some static helper methods to query DNS data + * based on the provided parameters. + */ +public final class YarnServiceUtils { + private YarnServiceUtils() { + } // This will be true only in UT. private static AppAdminClient stubServiceClient = null; - public static AppAdminClient createServiceClient( + static AppAdminClient createServiceClient( Configuration yarnConfiguration) { if (stubServiceClient != null) { return stubServiceClient; } - AppAdminClient serviceClient = AppAdminClient.createAppAdminClient( - DEFAULT_TYPE, yarnConfiguration); - return serviceClient; + return AppAdminClient.createAppAdminClient(DEFAULT_TYPE, yarnConfiguration); } @VisibleForTesting @@ -57,77 +51,9 @@ public static String getDNSName(String serviceName, domain, port); } - private static String getDNSNameCommonSuffix(String serviceName, + public static String getDNSNameCommonSuffix(String serviceName, String userName, String domain, int port) { return "." + serviceName + "." + userName + "." 
+ domain + ":" + port; } - public static String getTFConfigEnv(String curCommponentName, int nWorkers, - int nPs, String serviceName, String userName, String domain) { - String commonEndpointSuffix = getDNSNameCommonSuffix(serviceName, userName, - domain, 8000); - - String json = "{\\\"cluster\\\":{"; - - String master = getComponentArrayJson("master", 1, commonEndpointSuffix) - + ","; - String worker = getComponentArrayJson("worker", nWorkers - 1, - commonEndpointSuffix) + ","; - String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},"; - - StringBuilder sb = new StringBuilder(); - sb.append("\\\"task\\\":{"); - sb.append(" \\\"type\\\":\\\""); - sb.append(curCommponentName); - sb.append("\\\","); - sb.append(" \\\"index\\\":"); - sb.append('$'); - sb.append(Envs.TASK_INDEX_ENV + "},"); - String task = sb.toString(); - String environment = "\\\"environment\\\":\\\"cloud\\\"}"; - - sb = new StringBuilder(); - sb.append(json); - sb.append(master); - sb.append(worker); - sb.append(ps); - sb.append(task); - sb.append(environment); - return sb.toString(); - } - - public static void addQuicklink(Service serviceSpec, String label, - String link) { - Map quicklinks = serviceSpec.getQuicklinks(); - if (null == quicklinks) { - quicklinks = new HashMap<>(); - serviceSpec.setQuicklinks(quicklinks); - } - - if (SubmarineLogs.isVerbose()) { - LOG.info("Added quicklink, " + label + "=" + link); - } - - quicklinks.put(label, link); - } - - private static String getComponentArrayJson(String componentName, int count, - String endpointSuffix) { - String component = "\\\"" + componentName + "\\\":"; - StringBuilder array = new StringBuilder(); - array.append("["); - for (int i = 0; i < count; i++) { - array.append("\\\""); - array.append(componentName); - array.append("-"); - array.append(i); - array.append(endpointSuffix); - array.append("\\\""); - if (i != count - 1) { - array.append(","); - } - } - array.append("]"); - return component + array.toString(); - } } 
/**
 * Abstract base class for Launch command implementations for Services.
 * Currently we have launch command implementations
 * for TensorFlow PS, worker and Tensorboard instances.
 */
public abstract class AbstractLaunchCommand {
  // Builder that accumulates script fragments and writes the final
  // launch script file; shared with subclasses via getBuilder().
  private final LaunchScriptBuilder builder;

  /**
   * Sets up the script builder for the given task type.
   * @param hadoopEnvSetup helper that injects the HDFS classpath setup
   * @param taskType type of the task (worker/PS/Tensorboard); must not be null
   *                 because its name prefixes the generated script filename
   * @param component the YARN Service component the script belongs to
   * @param parameters the parsed run-job CLI parameters
   * @throws IOException if the backing temp script file cannot be created
   */
  public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
      TaskType taskType, Component component, RunJobParameters parameters)
      throws IOException {
    Objects.requireNonNull(taskType, "TaskType must not be null!");
    this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
        parameters, component);
  }

  protected LaunchScriptBuilder getBuilder() {
    return builder;
  }

  /**
   * Subclasses need to define this method and return a valid launch script.
   * Implementors can utilize the {@link LaunchScriptBuilder} using
   * the getBuilder method of this class.
   * @return The contents of a script.
   * @throws IOException If any IO issue happens.
   */
  public abstract String generateLaunchScript() throws IOException;

  /**
   * Subclasses need to provide a service-specific launch command
   * of the service.
   * Please note that this method should only return the launch command
   * but not the whole script.
   * @return The launch command
   */
  public abstract String createLaunchCommand();

}
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand; + +import java.io.IOException; +import java.util.Objects; + +/** + * Simple factory to create instances of {@link AbstractLaunchCommand} + * based on the {@link TaskType}. + * All dependencies are passed to this factory that could be required + * by any implementor of {@link AbstractLaunchCommand}. 
+ */ +public class LaunchCommandFactory { + private final HadoopEnvironmentSetup hadoopEnvSetup; + private final RunJobParameters parameters; + private final Configuration yarnConfig; + + public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup, + RunJobParameters parameters, Configuration yarnConfig) { + this.hadoopEnvSetup = hadoopEnvSetup; + this.parameters = parameters; + this.yarnConfig = yarnConfig; + } + + public AbstractLaunchCommand createLaunchCommand(TaskType taskType, + Component component) throws IOException { + Objects.requireNonNull(taskType, "TaskType must not be null!"); + + if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) { + return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType, + component, parameters, yarnConfig); + + } else if (taskType == TaskType.PS) { + return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component, + parameters, yarnConfig); + + } else if (taskType == TaskType.TENSORBOARD) { + return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component, + parameters); + } + throw new IllegalStateException("Unknown task type: " + taskType); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java new file mode 100644 index 00000000000..d24a0a772e1 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/LaunchScriptBuilder.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; + +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; + +import static java.nio.charset.StandardCharsets.UTF_8; + +/** + * This class is a builder to conveniently create launch scripts. + * All dependencies are provided with the constructor except + * the launch command. 
+ */ +public class LaunchScriptBuilder { + private static final Logger LOG = LoggerFactory.getLogger( + LaunchScriptBuilder.class); + + private final File file; + private final HadoopEnvironmentSetup hadoopEnvSetup; + private final RunJobParameters parameters; + private final Component component; + private final OutputStreamWriter writer; + private final StringBuilder scriptBuffer; + private String launchCommand; + + LaunchScriptBuilder(String namePrefix, + HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters, + Component component) throws IOException { + this.file = File.createTempFile(namePrefix + "-launch-script", ".sh"); + this.hadoopEnvSetup = hadoopEnvSetup; + this.parameters = parameters; + this.component = component; + this.writer = new OutputStreamWriter(new FileOutputStream(file), UTF_8); + this.scriptBuffer = new StringBuilder(); + } + + public void append(String s) { + scriptBuffer.append(s); + } + + public LaunchScriptBuilder withLaunchCommand(String command) { + this.launchCommand = command; + return this; + } + + public String build() throws IOException { + if (launchCommand != null) { + append(launchCommand); + } else { + LOG.warn("LaunchScript object was null!"); + if (LOG.isDebugEnabled()) { + LOG.debug("LaunchScript's Builder object: {}", this); + } + } + + try (PrintWriter pw = new PrintWriter(writer)) { + writeBashHeader(pw); + hadoopEnvSetup.addHdfsClassPath(parameters, pw, component); + if (LOG.isDebugEnabled()) { + LOG.debug("Appending command to launch script: {}", scriptBuffer); + } + pw.append(scriptBuffer); + } + return file.getAbsolutePath(); + } + + @Override + public String toString() { + return "LaunchScriptBuilder{" + + "file=" + file + + ", hadoopEnvSetup=" + hadoopEnvSetup + + ", parameters=" + parameters + + ", component=" + component + + ", writer=" + writer + + ", scriptBuffer=" + scriptBuffer + + ", launchCommand='" + launchCommand + '\'' + + '}'; + } + + private void writeBashHeader(PrintWriter pw) { + 
pw.append("#!/bin/bash\n"); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/package-info.java new file mode 100644 index 00000000000..a2572047e3a --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package contains classes to produce launch commands and scripts. 
/**
 * This class has common helper methods for TensorFlow.
 */
public final class TensorFlowCommons {
  private TensorFlowCommons() {
    throw new UnsupportedOperationException("This class should not be " +
        "instantiated!");
  }

  // Exposes the component id and task type to the container processes
  // through environment variables.
  public static void addCommonEnvironments(Component component,
      TaskType taskType) {
    Map<String, String> envs = component.getConfiguration().getEnv();
    envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
    envs.put(Envs.TASK_TYPE_ENV, taskType.name());
  }

  public static String getUserName() {
    return System.getProperty("user.name");
  }

  public static String getDNSDomain(Configuration yarnConfig) {
    return yarnConfig.get("hadoop.registry.dns.domain-name");
  }

  public static String getScriptFileName(TaskType taskType) {
    return "run-" + taskType.name() + ".sh";
  }

  /**
   * Builds the value of the TF_CONFIG environment variable as a
   * backslash-escaped JSON string (escaped so it survives being embedded
   * in a quoted shell export statement).
   * NOTE(review): one "master" plus (nWorkers - 1) "worker" entries are
   * emitted — presumably the primary worker is reported as the master;
   * confirm against the component setup.
   */
  public static String getTFConfigEnv(String componentName, int nWorkers,
      int nPs, String serviceName, String userName, String domain) {
    // All cluster members share the same DNS suffix; port 8000 is used
    // for every endpoint here.
    String commonEndpointSuffix = YarnServiceUtils
        .getDNSNameCommonSuffix(serviceName, userName, domain, 8000);

    String json = "{\\\"cluster\\\":{";

    String master = getComponentArrayJson("master", 1, commonEndpointSuffix)
        + ",";
    String worker = getComponentArrayJson("worker", nWorkers - 1,
        commonEndpointSuffix) + ",";
    String ps = getComponentArrayJson("ps", nPs, commonEndpointSuffix) + "},";

    StringBuilder sb = new StringBuilder();
    sb.append("\\\"task\\\":{");
    sb.append(" \\\"type\\\":\\\"");
    sb.append(componentName);
    sb.append("\\\",");
    sb.append(" \\\"index\\\":");
    // The index is left as a shell variable reference ($TASK_INDEX...),
    // substituted at container launch time from the environment.
    sb.append('$');
    sb.append(Envs.TASK_INDEX_ENV + "},");
    String task = sb.toString();
    String environment = "\\\"environment\\\":\\\"cloud\\\"}";

    sb = new StringBuilder();
    sb.append(json);
    sb.append(master);
    sb.append(worker);
    sb.append(ps);
    sb.append(task);
    sb.append(environment);
    return sb.toString();
  }

  // Renders one cluster entry, e.g.
  // \"worker\":[\"worker-0<suffix>\",\"worker-1<suffix>\"]
  private static String getComponentArrayJson(String componentName, int count,
      String endpointSuffix) {
    String component = "\\\"" + componentName + "\\\":";
    StringBuilder array = new StringBuilder();
    array.append("[");
    for (int i = 0; i < count; i++) {
      array.append("\\\"");
      array.append(componentName);
      array.append("-");
      array.append(i);
      array.append(endpointSuffix);
      array.append("\\\"");
      if (i != count - 1) {
        array.append(",");
      }
    }
    array.append("]");
    return component + array.toString();
  }
}
/**
 * This class contains all the logic to create an instance
 * of a {@link Service} object for TensorFlow.
 * Worker, PS and Tensorboard components are added to the Service
 * based on the value of the received {@link RunJobParameters}.
 */
public class TensorFlowServiceSpec implements ServiceSpec {
  private static final Logger LOG =
      LoggerFactory.getLogger(TensorFlowServiceSpec.class);

  private final RemoteDirectoryManager remoteDirectoryManager;

  private final RunJobParameters parameters;
  private final Configuration yarnConfig;
  private final FileSystemOperations fsOperations;
  private final LaunchCommandFactory launchCommandFactory;
  private final Localizer localizer;

  public TensorFlowServiceSpec(RunJobParameters parameters,
      ClientContext clientContext, FileSystemOperations fsOperations,
      LaunchCommandFactory launchCommandFactory, Localizer localizer) {
    this.parameters = parameters;
    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
    this.yarnConfig = clientContext.getYarnConfig();
    this.fsOperations = fsOperations;
    this.launchCommandFactory = launchCommandFactory;
    this.localizer = localizer;
  }

  /**
   * Builds the Service spec: base spec first, then worker / PS /
   * Tensorboard components as requested by the parameters, and
   * finally the quicklinks (which must see all components).
   * @return the wrapped Service spec
   * @throws IOException if creating any component or script fails
   */
  @Override
  public ServiceWrapper create() throws IOException {
    ServiceWrapper serviceWrapper = createServiceSpecWrapper();

    if (parameters.getNumWorkers() > 0) {
      addWorkerComponents(serviceWrapper);
    }

    if (parameters.getNumPS() > 0) {
      addPsComponent(serviceWrapper);
    }

    if (parameters.isTensorboardEnabled()) {
      createTensorBoardComponent(serviceWrapper);
    }

    // After all components added, handle quicklinks
    handleQuicklinks(serviceWrapper.getService());

    return serviceWrapper;
  }

  // Creates the bare Service (name, version, docker artifact, optional
  // Kerberos principal, envs and localizations) without any components.
  private ServiceWrapper createServiceSpecWrapper() throws IOException {
    Service serviceSpec = new Service();
    serviceSpec.setName(parameters.getName());
    // Version is just a unique stamp; the current time suffices.
    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));

    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
        .create(fsOperations, remoteDirectoryManager, parameters);
    if (kerberosPrincipal != null) {
      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
    }

    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
    localizer.handleLocalizations(serviceSpec);
    return new ServiceWrapper(serviceSpec);
  }

  private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
      throws IOException {
    TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
    serviceWrapper.addComponent(tbComponent);

    addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
        tbComponent.getTensorboardLink());
  }

  // Adds a single quicklink, lazily creating the quicklinks map.
  private static void addQuicklink(Service serviceSpec, String label,
      String link) {
    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
    if (quicklinks == null) {
      quicklinks = new HashMap<>();
      serviceSpec.setQuicklinks(quicklinks);
    }

    if (SubmarineLogs.isVerbose()) {
      LOG.info("Added quicklink, " + label + "=" + link);
    }

    quicklinks.put(label, link);
  }

  // Validates each user-requested quicklink against the actual component
  // instances ("<component>-<index>") and adds its DNS-based URL.
  private void handleQuicklinks(Service serviceSpec)
      throws IOException {
    List<Quicklink> quicklinks = parameters.getQuicklinks();
    if (quicklinks != null && !quicklinks.isEmpty()) {
      for (Quicklink ql : quicklinks) {
        // Make sure it is a valid instance name
        String instanceName = ql.getComponentInstanceName();
        boolean found = false;

        for (Component comp : serviceSpec.getComponents()) {
          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
            String possibleInstanceName = comp.getName() + "-" + i;
            if (possibleInstanceName.equals(instanceName)) {
              found = true;
              break;
            }
          }
        }

        if (!found) {
          throw new IOException(
              "Couldn't find a component instance = " + instanceName
                  + " while adding quicklink");
        }

        String link = ql.getProtocol()
            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
            getUserName(), getDNSDomain(yarnConfig), ql.getPort());
        addQuicklink(serviceSpec, ql.getLabel(), link);
      }
    }
  }

  // Handle worker and primary_worker.

  // The first worker is always the PRIMARY_WORKER; the remaining
  // (numWorkers - 1) instances form the WORKER component.
  private void addWorkerComponents(ServiceWrapper serviceWrapper)
      throws IOException {
    addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);

    if (parameters.getNumWorkers() > 1) {
      addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
    }
  }

  private void addWorkerComponent(ServiceWrapper serviceWrapper,
      RunJobParameters parameters, TaskType taskType) throws IOException {
    serviceWrapper.addComponent(
        new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
            parameters, taskType, launchCommandFactory, yarnConfig));
  }

  private void addPsComponent(ServiceWrapper serviceWrapper)
      throws IOException {
    serviceWrapper.addComponent(
        new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
            launchCommandFactory, parameters, yarnConfig));
  }

}
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Objects; + +/** + * Launch command implementation for Tensorboard. 
+ */ +public class TensorBoardLaunchCommand extends AbstractLaunchCommand { + private static final Logger LOG = + LoggerFactory.getLogger(TensorBoardLaunchCommand.class); + private final String checkpointPath; + + public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, + TaskType taskType, Component component, RunJobParameters parameters) + throws IOException { + super(hadoopEnvSetup, taskType, component, parameters); + Objects.requireNonNull(parameters.getCheckpointPath(), + "CheckpointPath must not be null as it is part " + + "of the tensorboard command!"); + if (StringUtils.isEmpty(parameters.getCheckpointPath())) { + throw new IllegalArgumentException("CheckpointPath must not be empty!"); + } + + this.checkpointPath = parameters.getCheckpointPath(); + } + + @Override + public String generateLaunchScript() throws IOException { + return getBuilder() + .withLaunchCommand(createLaunchCommand()) + .build(); + } + + @Override + public String createLaunchCommand() { + String tbCommand = String.format("export LC_ALL=C && tensorboard " + + "--logdir=%s%n", checkpointPath); + LOG.info("Tensorboard command=" + tbCommand); + return tbCommand; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java new file mode 100644 index 00000000000..07a18113145 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowLaunchCommand.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Launch command implementation for + * TensorFlow PS and Worker Service components. 
+ */ +public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand { + private static final Logger LOG = + LoggerFactory.getLogger(TensorFlowLaunchCommand.class); + private final Configuration yarnConfig; + private final boolean distributed; + private final int numberOfWorkers; + private final int numberOfPS; + private final String name; + private final TaskType taskType; + + TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, + TaskType taskType, Component component, RunJobParameters parameters, + Configuration yarnConfig) throws IOException { + super(hadoopEnvSetup, taskType, component, parameters); + this.taskType = taskType; + this.name = parameters.getName(); + this.distributed = parameters.isDistributed(); + this.numberOfWorkers = parameters.getNumWorkers(); + this.numberOfPS = parameters.getNumPS(); + this.yarnConfig = yarnConfig; + logReceivedParameters(); + } + + private void logReceivedParameters() { + if (this.numberOfWorkers <= 0) { + LOG.warn("Received number of workers: {}", this.numberOfWorkers); + } + if (this.numberOfPS <= 0) { + LOG.warn("Received number of PS: {}", this.numberOfPS); + } + } + + @Override + public String generateLaunchScript() throws IOException { + LaunchScriptBuilder builder = getBuilder(); + + // When distributed training is required + if (distributed) { + String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv( + taskType.getComponentName(), numberOfWorkers, + numberOfPS, name, + TensorFlowCommons.getUserName(), + TensorFlowCommons.getDNSDomain(yarnConfig)); + String tfConfig = "export TF_CONFIG=\"" + tfConfigEnvValue + "\"\n"; + builder.append(tfConfig); + } + + return builder + .withLaunchCommand(createLaunchCommand()) + .build(); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java new file mode 100644 index 00000000000..e1aca40e931 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowPsLaunchCommand.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Launch command implementation for TensorFlow's PS component.
+ */ +public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand { + private static final Logger LOG = + LoggerFactory.getLogger(TensorFlowPsLaunchCommand.class); + private final String launchCommand; + + public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, + TaskType taskType, Component component, RunJobParameters parameters, + Configuration yarnConfig) throws IOException { + super(hadoopEnvSetup, taskType, component, parameters, yarnConfig); + this.launchCommand = parameters.getPSLaunchCmd(); + + if (StringUtils.isEmpty(this.launchCommand)) { + throw new IllegalArgumentException("LaunchCommand must not be null " + + "or empty!"); + } + } + + @Override + public String createLaunchCommand() { + if (SubmarineLogs.isVerbose()) { + LOG.info("PS command =[" + launchCommand + "]"); + } + return launchCommand + '\n'; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java new file mode 100644 index 00000000000..734d8799c55 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TensorFlowWorkerLaunchCommand.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +
+import java.io.IOException; +
+/** + * Launch command implementation for TensorFlow's Worker component.
+ */ +public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand { + private static final Logger LOG = + LoggerFactory.getLogger(TensorFlowWorkerLaunchCommand.class); + private final String launchCommand; + + public TensorFlowWorkerLaunchCommand( + HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType, + Component component, RunJobParameters parameters, + Configuration yarnConfig) throws IOException { + super(hadoopEnvSetup, taskType, component, parameters, yarnConfig); + this.launchCommand = parameters.getWorkerLaunchCmd(); + + if (StringUtils.isEmpty(this.launchCommand)) { + throw new IllegalArgumentException("LaunchCommand must not be null " + + "or empty!"); + } + } + + @Override + public String createLaunchCommand() { + if (SubmarineLogs.isVerbose()) { + LOG.info("Worker command =[" + launchCommand + "]"); + } + return launchCommand + '\n'; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/package-info.java new file mode 100644 index 00000000000..f8df3bbf123 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package contains classes to generate TensorFlow launch commands. + */ +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java new file mode 100644 index 00000000000..2b9c1ca2551 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorBoardComponent.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Objects; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments; +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain; +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName; +import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact; +import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource; + +/** + * Component implementation for TensorBoard.
+ */ +public class TensorBoardComponent extends AbstractComponent { + private static final Logger LOG = + LoggerFactory.getLogger(TensorBoardComponent.class); + + public static final String TENSORBOARD_QUICKLINK_LABEL = "Tensorboard"; + private static final int DEFAULT_PORT = 6006; + + //computed fields + private String tensorboardLink; + + public TensorBoardComponent(FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + RunJobParameters parameters, + LaunchCommandFactory launchCommandFactory, + Configuration yarnConfig) { + super(fsOperations, remoteDirectoryManager, parameters, + TaskType.TENSORBOARD, yarnConfig, launchCommandFactory); + } + + @Override + public Component createComponent() throws IOException { + Objects.requireNonNull(parameters.getTensorboardResource(), + "TensorBoard resource must not be null!"); + + Component component = new Component(); + component.setName(taskType.getComponentName()); + component.setNumberOfContainers(1L); + component.setRestartPolicy(RestartPolicyEnum.NEVER); + component.setResource(convertYarnResourceToServiceResource( + parameters.getTensorboardResource())); + + if (parameters.getTensorboardDockerImage() != null) { + component.setArtifact( + getDockerArtifact(parameters.getTensorboardDockerImage())); + } + + addCommonEnvironments(component, taskType); + generateLaunchCommand(component); + + tensorboardLink = "http://" + YarnServiceUtils.getDNSName( + parameters.getName(), + taskType.getComponentName() + "-" + 0, getUserName(), + getDNSDomain(yarnConfig), DEFAULT_PORT); + LOG.info("Link to tensorboard:" + tensorboardLink); + + return component; + } + + public String getTensorboardLink() { + return tensorboardLink; + } + +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowPsComponent.java 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowPsComponent.java new file mode 100644 index 00000000000..c70e1328592 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowPsComponent.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; + +import java.io.IOException; +import java.util.Objects; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments; +import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact; +import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource; + +/** + * Component implementation for TensorFlow's PS process. 
+ */ +public class TensorFlowPsComponent extends AbstractComponent { + public TensorFlowPsComponent(FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + LaunchCommandFactory launchCommandFactory, + RunJobParameters parameters, + Configuration yarnConfig) { + super(fsOperations, remoteDirectoryManager, parameters, TaskType.PS, + yarnConfig, launchCommandFactory); + } + + @Override + public Component createComponent() throws IOException { + Objects.requireNonNull(parameters.getPsResource(), + "PS resource must not be null!"); + if (parameters.getNumPS() < 1) { + throw new IllegalArgumentException("Number of PS should be at least 1!"); + } + + Component component = new Component(); + component.setName(taskType.getComponentName()); + component.setNumberOfContainers((long) parameters.getNumPS()); + component.setRestartPolicy(Component.RestartPolicyEnum.NEVER); + component.setResource( + convertYarnResourceToServiceResource(parameters.getPsResource())); + + // Override global docker image if needed. 
+ if (parameters.getPsDockerImage() != null) { + component.setArtifact( + getDockerArtifact(parameters.getPsDockerImage())); + } + addCommonEnvironments(component, taskType); + generateLaunchCommand(component); + + return component; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowWorkerComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowWorkerComponent.java new file mode 100644 index 00000000000..74960403704 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TensorFlowWorkerComponent.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; +import java.io.IOException; +import java.util.Objects; +import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE; +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments; +import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact; +import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource; + +/** + * Component implementation for TensorFlow's Worker process. 
+ */ +public class TensorFlowWorkerComponent extends AbstractComponent { + public TensorFlowWorkerComponent(FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + RunJobParameters parameters, TaskType taskType, + LaunchCommandFactory launchCommandFactory, + Configuration yarnConfig) { + super(fsOperations, remoteDirectoryManager, parameters, taskType, + yarnConfig, launchCommandFactory); + } + + @Override + public Component createComponent() throws IOException { + Objects.requireNonNull(parameters.getWorkerResource(), + "Worker resource must not be null!"); + if (parameters.getNumWorkers() < 1) { + throw new IllegalArgumentException( + "Number of workers should be at least 1!"); + } + + Component component = new Component(); + component.setName(taskType.getComponentName()); + + if (taskType.equals(TaskType.PRIMARY_WORKER)) { + component.setNumberOfContainers(1L); + component.getConfiguration().setProperty( + CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true"); + } else { + component.setNumberOfContainers( + (long) parameters.getNumWorkers() - 1); + } + + if (parameters.getWorkerDockerImage() != null) { + component.setArtifact( + getDockerArtifact(parameters.getWorkerDockerImage())); + } + + component.setResource(convertYarnResourceToServiceResource( + parameters.getWorkerResource())); + component.setRestartPolicy(Component.RestartPolicyEnum.NEVER); + + addCommonEnvironments(component, taskType); + generateLaunchCommand(component); + + return component; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/package-info.java new file mode 100644 index 00000000000..10978b717ed --- /dev/null +++ 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package contains classes to generate + * TensorFlow Native Service components. + */ +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/package-info.java new file mode 100644 index 00000000000..0c514858433 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package contains classes to generate + * TensorFlow-related Native Service runtime artifacts. + */ +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow; \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ClassPathUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ClassPathUtilities.java new file mode 100644 index 00000000000..fc8f6ea6201 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ClassPathUtilities.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import java.io.File; +import java.util.StringTokenizer; + +/** + * Utilities for classpath operations. + */ +public final class ClassPathUtilities { + private ClassPathUtilities() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + public static File findFileOnClassPath(final String fileName) { + final String classpath = System.getProperty("java.class.path"); + final String pathSeparator = System.getProperty("path.separator"); + final StringTokenizer tokenizer = new StringTokenizer(classpath, + pathSeparator); + + while (tokenizer.hasMoreTokens()) { + final String pathElement = tokenizer.nextToken(); + final File directoryOrJar = new File(pathElement); + final File absoluteDirectoryOrJar = directoryOrJar.getAbsoluteFile(); + if (absoluteDirectoryOrJar.isFile()) { + final File target = + new File(absoluteDirectoryOrJar.getParent(), fileName); + if (target.exists()) { + return target; + } + } else { + final File target = new File(directoryOrJar, fileName); + if (target.exists()) { + return target; + } + } + } + + return null; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/DockerUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/DockerUtilities.java new file mode 100644 index 00000000000..78cee339676 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/DockerUtilities.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import org.apache.hadoop.yarn.service.api.records.Artifact; + +/** + * Utilities for Docker-related operations. + */ +public final class DockerUtilities { + private DockerUtilities() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + public static Artifact getDockerArtifact(String dockerImageName) { + return new Artifact().type(Artifact.TypeEnum.DOCKER).id(dockerImageName); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/EnvironmentUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/EnvironmentUtilities.java new file mode 100644 index 00000000000..f4ef7b4e7a6 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/EnvironmentUtilities.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Service; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; + +import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION; + +/** + * Utilities for environment variable related operations + * for {@link Service} objects. + */ +public final class EnvironmentUtilities { + private EnvironmentUtilities() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + private static final Logger LOG = + LoggerFactory.getLogger(EnvironmentUtilities.class); + + static final String ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME = + "YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"; + private static final String MOUNTS_DELIM = ","; + private static final String ENV_SEPARATOR = "="; + private static final String ETC_PASSWD_MOUNT_STRING = + "/etc/passwd:/etc/passwd:ro"; + private static final String KERBEROS_CONF_MOUNT_STRING = + "/etc/krb5.conf:/etc/krb5.conf:ro"; + private static final String ENV_VAR_DELIM = ":"; + + /** + * Extracts value from a string representation of an environment variable. + * @param envVar The environment variable in 'key=value' format. 
+ * @return The value of the environment variable + */ + public static String getValueOfEnvironment(String envVar) { + if (envVar == null || !envVar.contains(ENV_SEPARATOR)) { + return ""; + } else { + return envVar.substring(envVar.indexOf(ENV_SEPARATOR) + 1); + } + } + + public static void handleServiceEnvs(Service service, + Configuration yarnConfig, List envVars) { + if (envVars != null) { + for (String envVarPair : envVars) { + String key, value; + if (envVarPair.contains(ENV_SEPARATOR)) { + int idx = envVarPair.indexOf(ENV_SEPARATOR); + key = envVarPair.substring(0, idx); + value = envVarPair.substring(idx + 1); + } else { + LOG.warn("Found environment variable with unusual format: '{}'", + envVarPair); + // No "=" found so use the whole key + key = envVarPair; + value = ""; + } + appendToEnv(service, key, value, ENV_VAR_DELIM); + } + } + appendOtherConfigs(service, yarnConfig); + } + + /** + * Appends other configs like /etc/passwd, /etc/krb5.conf. + * @param service + * @param yarnConfig + */ + private static void appendOtherConfigs(Service service, + Configuration yarnConfig) { + appendToEnv(service, ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME, + ETC_PASSWD_MOUNT_STRING, MOUNTS_DELIM); + + String authentication = yarnConfig.get(HADOOP_SECURITY_AUTHENTICATION); + if (authentication != null && authentication.equals("kerberos")) { + appendToEnv(service, ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME, + KERBEROS_CONF_MOUNT_STRING, MOUNTS_DELIM); + } + } + + static void appendToEnv(Service service, String key, String value, + String delim) { + Map env = service.getConfiguration().getEnv(); + if (!env.containsKey(key)) { + env.put(key, value); + } else { + if (!value.isEmpty()) { + String existingValue = env.get(key); + if (!existingValue.endsWith(delim)) { + env.put(key, existingValue + delim + value); + } else { + env.put(key, existingValue + value); + } + } + } + } +} diff --git 
a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/KerberosPrincipalFactory.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/KerberosPrincipalFactory.java new file mode 100644 index 00000000000..a37f37b2a56 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/KerberosPrincipalFactory.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Objects; + +/** + * Simple factory that creates a {@link KerberosPrincipal}. 
+ */ +public final class KerberosPrincipalFactory { + private KerberosPrincipalFactory() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + private static final Logger LOG = + LoggerFactory.getLogger(KerberosPrincipalFactory.class); + + public static KerberosPrincipal create(FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + RunJobParameters parameters) throws IOException { + Objects.requireNonNull(fsOperations, + "FileSystemOperations must not be null!"); + Objects.requireNonNull(remoteDirectoryManager, + "RemoteDirectoryManager must not be null!"); + Objects.requireNonNull(parameters, "Parameters must not be null!"); + + if (StringUtils.isNotBlank(parameters.getKeytab()) && StringUtils + .isNotBlank(parameters.getPrincipal())) { + String keytab = parameters.getKeytab(); + String principal = parameters.getPrincipal(); + if (parameters.isDistributeKeytab()) { + return handleDistributedKeytab(fsOperations, remoteDirectoryManager, + parameters, keytab, principal); + } else { + return handleNormalKeytab(keytab, principal); + } + } + LOG.debug("Principal and keytab was null or empty, " + + "returning null KerberosPrincipal!"); + return null; + } + + private static KerberosPrincipal handleDistributedKeytab( + FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + RunJobParameters parameters, String keytab, String principal) + throws IOException { + Path stagingDir = remoteDirectoryManager + .getJobStagingArea(parameters.getName(), true); + Path remoteKeytabPath = + fsOperations.uploadToRemoteFile(stagingDir, keytab); + // Only the owner has read access + fsOperations.setPermission(remoteKeytabPath, + FsPermission.createImmutable((short)Integer.parseInt("400", 8))); + return new KerberosPrincipal() + .keytab(remoteKeytabPath.toString()) + .principalName(principal); + } + + private static KerberosPrincipal handleNormalKeytab(String keytab, + String 
principal) { + if(!keytab.startsWith("file")) { + keytab = "file://" + keytab; + } + return new KerberosPrincipal() + .keytab(keytab) + .principalName(principal); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/Localizer.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/Localizer.java new file mode 100644 index 00000000000..c86f1a2b15b --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/Localizer.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.utils; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.service.api.records.ConfigFile; +import org.apache.hadoop.yarn.service.api.records.Service; +import org.apache.hadoop.yarn.submarine.client.cli.param.Localization; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.List; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs; +import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.appendToEnv; + +/** + * This class holds all dependencies in order to localize dependencies + * for containers. + */ +public class Localizer { + private static final Logger LOG = LoggerFactory.getLogger(Localizer.class); + + private final FileSystemOperations fsOperations; + private final RemoteDirectoryManager remoteDirectoryManager; + private final RunJobParameters parameters; + + public Localizer(FileSystemOperations fsOperations, + RemoteDirectoryManager remoteDirectoryManager, + RunJobParameters parameters) { + this.fsOperations = fsOperations; + this.remoteDirectoryManager = remoteDirectoryManager; + this.parameters = parameters; + } + + /** + * Localize dependencies for all containers. + * If remoteUri is a local directory, + * we'll zip it and upload it to the HDFS staging dir. + * If remoteUri is a remote directory, we'll download it, zip it and upload + * to HDFS.
+ * If localFilePath is ".", we'll use remoteUri's file/dir name + * */ + public void handleLocalizations(Service service) + throws IOException { + // Handle localizations + Path stagingDir = + remoteDirectoryManager.getJobStagingArea( + parameters.getName(), true); + List localizations = parameters.getLocalizations(); + String remoteUri; + String containerLocalPath; + + // Check to fail fast + for (Localization loc : localizations) { + remoteUri = loc.getRemoteUri(); + Path resourceToLocalize = new Path(remoteUri); + // Check if remoteUri exists + if (remoteDirectoryManager.isRemote(remoteUri)) { + // check if exists + if (!remoteDirectoryManager.existsRemoteFile(resourceToLocalize)) { + throw new FileNotFoundException( + "File " + remoteUri + " doesn't exists."); + } + } else { + // Check if exists + File localFile = new File(remoteUri); + if (!localFile.exists()) { + throw new FileNotFoundException( + "File " + remoteUri + " doesn't exists."); + } + } + // check remote file size + fsOperations.validFileSize(remoteUri); + } + // Start download remote if needed and upload to HDFS + for (Localization loc : localizations) { + remoteUri = loc.getRemoteUri(); + containerLocalPath = loc.getLocalPath(); + String srcFileStr = remoteUri; + ConfigFile.TypeEnum destFileType = ConfigFile.TypeEnum.STATIC; + Path resourceToLocalize = new Path(remoteUri); + boolean needUploadToHDFS = true; + + + // Special handling of remoteUri directory + boolean needDeleteTempFile = false; + if (remoteDirectoryManager.isDir(remoteUri)) { + destFileType = ConfigFile.TypeEnum.ARCHIVE; + srcFileStr = fsOperations.downloadAndZip( + remoteUri, getLastNameFromPath(srcFileStr), true); + } else if (remoteDirectoryManager.isRemote(remoteUri)) { + if (!needHdfs(remoteUri)) { + // Non HDFS remote uri. 
Non directory, no need to zip + srcFileStr = fsOperations.downloadAndZip( + remoteUri, getLastNameFromPath(srcFileStr), false); + needDeleteTempFile = true; + } else { + // HDFS file, no need to upload + needUploadToHDFS = false; + } + } + + // Upload file to HDFS + if (needUploadToHDFS) { + resourceToLocalize = + fsOperations.uploadToRemoteFile(stagingDir, srcFileStr); + } + if (needDeleteTempFile) { + fsOperations.deleteFiles(srcFileStr); + } + // Remove .zip from zipped dir name + if (destFileType == ConfigFile.TypeEnum.ARCHIVE + && srcFileStr.endsWith(".zip")) { + // Delete local zip file + fsOperations.deleteFiles(srcFileStr); + int suffixIndex = srcFileStr.lastIndexOf('_'); + srcFileStr = srcFileStr.substring(0, suffixIndex); + } + // If provided, use the name of local uri + if (!containerLocalPath.equals(".") + && !containerLocalPath.equals("./")) { + // Change the YARN localized file name to what'll used in container + srcFileStr = getLastNameFromPath(containerLocalPath); + } + String localizedName = getLastNameFromPath(srcFileStr); + LOG.info("The file/dir to be localized is {}", + resourceToLocalize.toString()); + LOG.info("Its localized file name will be {}", localizedName); + service.getConfiguration().getFiles().add(new ConfigFile().srcFile( + resourceToLocalize.toUri().toString()).destFile(localizedName) + .type(destFileType)); + // set mounts + // if mount path is absolute, just use it. 
+ // if relative, no need to mount explicitly + if (containerLocalPath.startsWith("/")) { + String mountStr = getLastNameFromPath(srcFileStr) + ":" + + containerLocalPath + ":" + loc.getMountPermission(); + LOG.info("Add bind-mount string {}", mountStr); + appendToEnv(service, + EnvironmentUtilities.ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME, + mountStr, ","); + } + } + } + + private String getLastNameFromPath(String srcFileStr) { + return new Path(srcFileStr).getName(); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/SubmarineResourceUtils.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/SubmarineResourceUtils.java new file mode 100644 index 00000000000..3d1a237a047 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/SubmarineResourceUtils.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.utils; + +import org.apache.hadoop.yarn.service.api.records.Resource; +import org.apache.hadoop.yarn.service.api.records.ResourceInformation; +import java.util.HashMap; +import java.util.Map; + +/** + * Resource utilities for Submarine. + */ +public final class SubmarineResourceUtils { + private SubmarineResourceUtils() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + public static Resource convertYarnResourceToServiceResource( + org.apache.hadoop.yarn.api.records.Resource yarnResource) { + Resource serviceResource = new Resource(); + serviceResource.setCpus(yarnResource.getVirtualCores()); + serviceResource.setMemory(String.valueOf(yarnResource.getMemorySize())); + + Map riMap = new HashMap<>(); + for (org.apache.hadoop.yarn.api.records.ResourceInformation ri : + yarnResource.getAllResourcesListCopy()) { + ResourceInformation serviceRi = new ResourceInformation(); + serviceRi.setValue(ri.getValue()); + serviceRi.setUnit(ri.getUnits()); + riMap.put(ri.getName(), serviceRi); + } + serviceResource.setResourceInformations(riMap); + + return serviceResource; + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ZipUtilities.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ZipUtilities.java new file mode 100644 index 00000000000..c75f2d33359 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/ZipUtilities.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +/** + * Utilities for zipping directories and adding existing directories to zips. 
+ */ +public final class ZipUtilities { + private ZipUtilities() { + throw new UnsupportedOperationException("This class should not be " + + "instantiated!"); + } + + private static final Logger LOG = LoggerFactory.getLogger(ZipUtilities.class); + + @VisibleForTesting + public static String zipDir(String srcDir, String dstFile) + throws IOException { + FileOutputStream fos = new FileOutputStream(dstFile); + ZipOutputStream zos = new ZipOutputStream(fos); + File srcFile = new File(srcDir); + LOG.info("Compressing directory {}", srcDir); + addDirToZip(zos, srcFile, srcFile); + // close the ZipOutputStream + zos.close(); + LOG.info("Compressed directory {} to file: {}", srcDir, dstFile); + return dstFile; + } + + private static void addDirToZip(ZipOutputStream zos, File srcFile, File base) + throws IOException { + File[] files = srcFile.listFiles(); + if (files == null) { + return; + } + for (File file : files) { + // if it's directory, add recursively + if (file.isDirectory()) { + addDirToZip(zos, file, base); + continue; + } + byte[] buffer = new byte[1024]; + try(FileInputStream fis = new FileInputStream(file)) { + String name = base.toURI().relativize(file.toURI()).getPath(); + LOG.info("Adding file {} to zip", name); + zos.putNextEntry(new ZipEntry(name)); + int length; + while ((length = fis.read(buffer)) > 0) { + zos.write(buffer, 0, length); + } + zos.flush(); + } finally { + zos.closeEntry(); + } + } + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/package-info.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/package-info.java new file mode 100644 index 00000000000..2f60d903cfd --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/main/java/org/apache/hadoop/yarn/submarine/utils/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package contains utility classes. + */ +package org.apache.hadoop.yarn.submarine.utils; \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/FileUtilitiesForTests.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/FileUtilitiesForTests.java new file mode 100644 index 00000000000..a5161f59f98 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/FileUtilitiesForTests.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine; + +import com.google.common.collect.Lists; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +import static org.junit.Assert.assertTrue; + +/** + * File utilities for tests. + * Provides methods that can create, delete files or directories + * in a temp directory, or any specified directory. + */ +public class FileUtilitiesForTests { + private static final Logger LOG = + LoggerFactory.getLogger(FileUtilitiesForTests.class); + private String tempDir; + private List cleanupFiles; + + public void setup() { + cleanupFiles = Lists.newArrayList(); + tempDir = System.getProperty("java.io.tmpdir"); + } + + public void teardown() throws IOException { + LOG.info("About to clean up files: " + cleanupFiles); + List dirs = Lists.newArrayList(); + for (File cleanupFile : cleanupFiles) { + if (cleanupFile.isDirectory()) { + dirs.add(cleanupFile); + } else { + deleteFile(cleanupFile); + } + } + + for (File dir : dirs) { + deleteFile(dir); + } + } + + public File createFileInTempDir(String filename) throws IOException { + File file = new File(tempDir, new Path(filename).getName()); + createFile(file); + return file; + } + + public File createDirInTempDir(String dirName) { + File file = new File(tempDir, new Path(dirName).getName()); + createDirectory(file); + return file; + } + + public File createFileInDir(Path dir, String filename) throws IOException { + File 
dirTmp = new File(dir.toUri().getPath()); + if (!dirTmp.exists()) { + createDirectory(dirTmp); + } + File file = + new File(dir.toUri().getPath() + "/" + new Path(filename).getName()); + createFile(file); + return file; + } + + public File createFileInDir(File dir, String filename) throws IOException { + if (!dir.exists()) { + createDirectory(dir); + } + File file = new File(dir, filename); + createFile(file); + return file; + } + + public File createDirectory(Path parent, String dirname) { + File dir = + new File(parent.toUri().getPath() + "/" + new Path(dirname).getName()); + createDirectory(dir); + return dir; + } + + public File createDirectory(File parent, String dirname) { + File dir = + new File(parent.getPath() + "/" + new Path(dirname).getName()); + createDirectory(dir); + return dir; + } + + private void createDirectory(File dir) { + boolean result = dir.mkdir(); + assertTrue("Failed to create directory " + dir.getAbsolutePath(), result); + assertTrue("Directory does not exist: " + dir.getAbsolutePath(), + dir.exists()); + this.cleanupFiles.add(dir); + } + + private void createFile(File file) throws IOException { + boolean result = file.createNewFile(); + assertTrue("Failed to create file " + file.getAbsolutePath(), result); + assertTrue("File does not exist: " + file.getAbsolutePath(), file.exists()); + this.cleanupFiles.add(file); + } + + private static void deleteFile(File file) throws IOException { + if (file.isDirectory()) { + LOG.info("Removing directory: " + file.getAbsolutePath()); + FileUtils.deleteDirectory(file); + } + + if (file.exists()) { + LOG.info("Removing file: " + file.getAbsolutePath()); + boolean result = file.delete(); + assertTrue("Deletion of file " + file.getAbsolutePath() + + " was not successful!", result); + } + } + + public File getTempFileWithName(String filename) { + return new File(tempDir + "/" + new Path(filename).getName()); + } + + public static File getFilename(Path parent, String filename) { + return new File( + 
parent.toUri().getPath() + "/" + new Path(filename).getName()); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/ParamBuilderForTest.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/ParamBuilderForTest.java new file mode 100644 index 00000000000..8a9b7e0ad4f --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/ParamBuilderForTest.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.client.cli.yarnservice; + +import com.google.common.collect.Lists; + +import java.util.List; + +class ParamBuilderForTest { + private final List params = Lists.newArrayList(); + + static ParamBuilderForTest create() { + return new ParamBuilderForTest(); + } + + ParamBuilderForTest withJobName(String name) { + params.add("--name"); + params.add(name); + return this; + } + + ParamBuilderForTest withDockerImage(String dockerImage) { + params.add("--docker_image"); + params.add(dockerImage); + return this; + } + + ParamBuilderForTest withInputPath(String inputPath) { + params.add("--input_path"); + params.add(inputPath); + return this; + } + + ParamBuilderForTest withCheckpointPath(String checkpointPath) { + params.add("--checkpoint_path"); + params.add(checkpointPath); + return this; + } + + ParamBuilderForTest withNumberOfWorkers(int numWorkers) { + params.add("--num_workers"); + params.add(String.valueOf(numWorkers)); + return this; + } + + ParamBuilderForTest withNumberOfPs(int numPs) { + params.add("--num_ps"); + params.add(String.valueOf(numPs)); + return this; + } + + ParamBuilderForTest withWorkerLaunchCommand(String launchCommand) { + params.add("--worker_launch_cmd"); + params.add(launchCommand); + return this; + } + + ParamBuilderForTest withPsLaunchCommand(String launchCommand) { + params.add("--ps_launch_cmd"); + params.add(launchCommand); + return this; + } + + ParamBuilderForTest withWorkerResources(String workerResources) { + params.add("--worker_resources"); + params.add(workerResources); + return this; + } + + ParamBuilderForTest withPsResources(String psResources) { + params.add("--ps_resources"); + params.add(psResources); + return this; + } + + ParamBuilderForTest withWorkerDockerImage(String dockerImage) { + params.add("--worker_docker_image"); + params.add(dockerImage); + return this; + } + + ParamBuilderForTest withPsDockerImage(String dockerImage) { + params.add("--ps_docker_image"); + 
params.add(dockerImage); + return this; + } + + ParamBuilderForTest withVerbose() { + params.add("--verbose"); + return this; + } + + ParamBuilderForTest withTensorboard() { + params.add("--tensorboard"); + return this; + } + + ParamBuilderForTest withTensorboardResources(String resources) { + params.add("--tensorboard_resources"); + params.add(resources); + return this; + } + + ParamBuilderForTest withTensorboardDockerImage(String dockerImage) { + params.add("--tensorboard_docker_image"); + params.add(dockerImage); + return this; + } + + ParamBuilderForTest withQuickLink(String quickLink) { + params.add("--quicklink"); + params.add(quickLink); + return this; + } + + ParamBuilderForTest withLocalization(String remoteUrl, String localUrl) { + params.add("--localization"); + params.add(remoteUrl + ":" + localUrl); + return this; + } + + String[] build() { + return params.toArray(new String[0]); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java index ee6b5c1078d..2a568cbdca0 100644 --- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCli.java @@ -20,26 +20,23 @@ import com.google.common.collect.ImmutableMap; import org.apache.hadoop.fs.FileUtil; -import org.apache.hadoop.fs.Path; import org.apache.hadoop.yarn.api.records.Resource; -import org.apache.hadoop.yarn.client.api.AppAdminClient; import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.service.api.records.Component; -import 
org.apache.hadoop.yarn.service.api.records.ConfigFile; import org.apache.hadoop.yarn.service.api.records.Service; import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli; import org.apache.hadoop.yarn.submarine.common.MockClientContext; import org.apache.hadoop.yarn.submarine.common.api.TaskType; -import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration; import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; -import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter; import org.apache.hadoop.yarn.submarine.runtimes.common.StorageKeyConstants; import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceJobSubmitter; -import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent; +import org.apache.hadoop.yarn.submarine.utils.ZipUtilities; import org.apache.hadoop.yarn.util.resource.Resources; -import org.junit.Assert; +import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,29 +45,41 @@ import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Paths; -import java.util.List; import java.util.Map; -import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static 
org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_CHECKPOINT_PATH; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_DOCKER_IMAGE; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_INPUT_PATH; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_JOB_NAME; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_PS_DOCKER_IMAGE; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_PS_LAUNCH_CMD; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_PS_RESOURCES; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_TENSORBOARD_DOCKER_IMAGE; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_TENSORBOARD_RESOURCES; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_WORKER_DOCKER_IMAGE; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_WORKER_LAUNCH_CMD; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.DEFAULT_WORKER_RESOURCES; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +/** + * Class to test YarnService with the Run job CLI action. 
+ */ public class TestYarnServiceRunJobCli { + private TestYarnServiceRunJobCliCommons testCommons = + new TestYarnServiceRunJobCliCommons(); + @Before public void before() throws IOException, YarnException { - SubmarineLogs.verboseOff(); - AppAdminClient serviceClient = mock(AppAdminClient.class); - when(serviceClient.actionLaunch(any(String.class), any(String.class), - any(Long.class), any(String.class))).thenReturn(EXIT_SUCCESS); - when(serviceClient.getStatusString(any(String.class))).thenReturn( - "{\"id\": \"application_1234_1\"}"); - YarnServiceUtils.setStubServiceClient(serviceClient); + testCommons.setup(); + } + + @After + public void cleanup() throws IOException { + testCommons.teardown(); } @Test @@ -81,53 +90,50 @@ public void testPrintHelp() { runJobCli.printUsages(); } - private Service getServiceSpecFromJobSubmitter(JobSubmitter jobSubmitter) { - return ((YarnServiceJobSubmitter) jobSubmitter).getServiceSpec(); + private ServiceWrapper getServiceWrapperFromJobSubmitter( + JobSubmitter jobSubmitter) { + return ((YarnServiceJobSubmitter) jobSubmitter).getServiceWrapper(); } - private void commonVerifyDistributedTrainingSpec(Service serviceSpec) - throws Exception { - Assert.assertTrue( - serviceSpec.getComponent(TaskType.WORKER.getComponentName()) != null); - Assert.assertTrue( - serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName()) - != null); - Assert.assertTrue( - serviceSpec.getComponent(TaskType.PS.getComponentName()) != null); + private void commonVerifyDistributedTrainingSpec(Service serviceSpec) { + assertNotNull(serviceSpec.getComponent(TaskType.WORKER.getComponentName())); + assertNotNull( + serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName())); + assertNotNull(serviceSpec.getComponent(TaskType.PS.getComponentName())); Component primaryWorkerComp = serviceSpec.getComponent( TaskType.PRIMARY_WORKER.getComponentName()); - Assert.assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB()); - 
Assert.assertEquals(2, + assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB()); + assertEquals(2, primaryWorkerComp.getResource().getCpus().intValue()); Component workerComp = serviceSpec.getComponent( TaskType.WORKER.getComponentName()); - Assert.assertEquals(2048, workerComp.getResource().calcMemoryMB()); - Assert.assertEquals(2, workerComp.getResource().getCpus().intValue()); + assertEquals(2048, workerComp.getResource().calcMemoryMB()); + assertEquals(2, workerComp.getResource().getCpus().intValue()); Component psComp = serviceSpec.getComponent(TaskType.PS.getComponentName()); - Assert.assertEquals(4096, psComp.getResource().calcMemoryMB()); - Assert.assertEquals(4, psComp.getResource().getCpus().intValue()); + assertEquals(4096, psComp.getResource().calcMemoryMB()); + assertEquals(4, psComp.getResource().getCpus().intValue()); - Assert.assertEquals("worker.image", workerComp.getArtifact().getId()); - Assert.assertEquals("ps.image", psComp.getArtifact().getId()); + assertEquals(DEFAULT_WORKER_DOCKER_IMAGE, workerComp.getArtifact().getId()); + assertEquals(DEFAULT_PS_DOCKER_IMAGE, psComp.getArtifact().getId()); - Assert.assertTrue(SubmarineLogs.isVerbose()); + assertTrue(SubmarineLogs.isVerbose()); } private void verifyQuicklink(Service serviceSpec, Map expectedQuicklinks) { Map actualQuicklinks = serviceSpec.getQuicklinks(); if (actualQuicklinks == null || actualQuicklinks.isEmpty()) { - Assert.assertTrue( + assertTrue( expectedQuicklinks == null || expectedQuicklinks.isEmpty()); return; } - Assert.assertEquals(expectedQuicklinks.size(), actualQuicklinks.size()); + assertEquals(expectedQuicklinks.size(), actualQuicklinks.size()); for (Map.Entry expectedEntry : expectedQuicklinks .entrySet()) { - Assert.assertTrue(actualQuicklinks.containsKey(expectedEntry.getKey())); + assertTrue(actualQuicklinks.containsKey(expectedEntry.getKey())); // $USER could be changed in different environment. 
so replace $USER by // "user" @@ -137,7 +143,7 @@ private void verifyQuicklink(Service serviceSpec, String userName = System.getProperty("user.name"); actualValue = actualValue.replaceAll(userName, "username"); - Assert.assertEquals(expectedValue, actualValue); + assertEquals(expectedValue, actualValue); } } @@ -146,19 +152,27 @@ public void testBasicRunJobForDistributedTraining() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose"}); - Service serviceSpec = getServiceSpecFromJobSubmitter( + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(3) + .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_WORKER_RESOURCES) + .withNumberOfPs(2) + .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE) + .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD) + .withPsResources(DEFAULT_PS_RESOURCES) + .withVerbose() + .build(); + runJobCli.run(params); + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(3, serviceSpec.getComponents().size()); + assertEquals(3, serviceSpec.getComponents().size()); 
commonVerifyDistributedTrainingSpec(serviceSpec); @@ -171,28 +185,37 @@ public void testBasicRunJobForDistributedTrainingWithTensorboard() MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--tensorboard", "--ps_launch_cmd", "python run-ps.py", - "--verbose"}); - Service serviceSpec = getServiceSpecFromJobSubmitter( + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(3) + .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_WORKER_RESOURCES) + .withNumberOfPs(2) + .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE) + .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD) + .withPsResources(DEFAULT_PS_RESOURCES) + .withVerbose() + .withTensorboard() + .build(); + runJobCli.run(params); + ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(4, serviceSpec.getComponents().size()); + Service serviceSpec = serviceWrapper.getService(); + assertEquals(4, serviceSpec.getComponents().size()); commonVerifyDistributedTrainingSpec(serviceSpec); - verifyTensorboardComponent(runJobCli, serviceSpec, + verifyTensorboardComponent(runJobCli, serviceWrapper, Resources.createResource(4096, 1)); 
verifyQuicklink(serviceSpec, ImmutableMap - .of(YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL, + .of(TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL, "http://tensorboard-0.my-job.username.null:6006")); } @@ -201,17 +224,23 @@ public void testBasicRunJobForSingleNodeTraining() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "1", "--worker_launch_cmd", "python run-job.py", - "--worker_resources", "memory=2G,vcores=2", "--verbose"}); + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(1) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES) + .withVerbose() + .build(); + runJobCli.run(params); - Service serviceSpec = getServiceSpecFromJobSubmitter( + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(1, serviceSpec.getComponents().size()); + assertEquals(1, serviceSpec.getComponents().size()); commonTestSingleNodeTraining(serviceSpec); } @@ -221,41 +250,53 @@ public void testTensorboardOnlyService() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", 
"s3://output", - "--num_workers", "0", "--tensorboard", "--verbose"}); + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(0) + .withTensorboard() + .withVerbose() + .build(); + runJobCli.run(params); - Service serviceSpec = getServiceSpecFromJobSubmitter( + ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(1, serviceSpec.getComponents().size()); + assertEquals(1, serviceWrapper.getService().getComponents().size()); - verifyTensorboardComponent(runJobCli, serviceSpec, + verifyTensorboardComponent(runJobCli, serviceWrapper, Resources.createResource(4096, 1)); } @Test - public void testTensorboardOnlyServiceWithCustomizedDockerImageAndResourceCkptPath() + public void testTensorboardOnlyServiceWithCustomDockerImageAndCheckpointPath() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "0", "--tensorboard", "--verbose", - "--tensorboard_resources", "memory=2G,vcores=2", - "--tensorboard_docker_image", "tb_docker_image:001"}); + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(0) + .withTensorboard() + .withTensorboardResources(DEFAULT_TENSORBOARD_RESOURCES) + .withTensorboardDockerImage(DEFAULT_TENSORBOARD_DOCKER_IMAGE) + .withVerbose() + .build(); + runJobCli.run(params); - Service 
serviceSpec = getServiceSpecFromJobSubmitter( + ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(1, serviceSpec.getComponents().size()); + assertEquals(1, serviceWrapper.getService().getComponents().size()); - verifyTensorboardComponent(runJobCli, serviceSpec, + verifyTensorboardComponent(runJobCli, serviceWrapper, Resources.createResource(2048, 2)); } @@ -265,94 +306,92 @@ public void testTensorboardOnlyServiceWithCustomizedDockerImageAndResource() MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--num_workers", "0", "--tensorboard", "--verbose", - "--tensorboard_resources", "memory=2G,vcores=2", - "--tensorboard_docker_image", "tb_docker_image:001"}); + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withNumberOfWorkers(0) + .withTensorboard() + .withTensorboardResources(DEFAULT_TENSORBOARD_RESOURCES) + .withTensorboardDockerImage(DEFAULT_TENSORBOARD_DOCKER_IMAGE) + .withVerbose() + .build(); + runJobCli.run(params); - Service serviceSpec = getServiceSpecFromJobSubmitter( + ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(1, serviceSpec.getComponents().size()); + assertEquals(1, serviceWrapper.getService().getComponents().size()); - verifyTensorboardComponent(runJobCli, serviceSpec, + verifyTensorboardComponent(runJobCli, serviceWrapper, Resources.createResource(2048, 2)); - verifyQuicklink(serviceSpec, ImmutableMap - .of(YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL, + verifyQuicklink(serviceWrapper.getService(), ImmutableMap + 
.of(TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL, "http://tensorboard-0.my-job.username.null:6006")); } - private void commonTestSingleNodeTraining(Service serviceSpec) - throws Exception { - Assert.assertTrue( - serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName()) - != null); + private void commonTestSingleNodeTraining(Service serviceSpec) { + assertNotNull( + serviceSpec.getComponent(TaskType.PRIMARY_WORKER.getComponentName())); Component primaryWorkerComp = serviceSpec.getComponent( TaskType.PRIMARY_WORKER.getComponentName()); - Assert.assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB()); - Assert.assertEquals(2, + assertEquals(2048, primaryWorkerComp.getResource().calcMemoryMB()); + assertEquals(2, primaryWorkerComp.getResource().getCpus().intValue()); - Assert.assertTrue(SubmarineLogs.isVerbose()); + assertTrue(SubmarineLogs.isVerbose()); } private void verifyTensorboardComponent(RunJobCli runJobCli, - Service serviceSpec, Resource resource) throws Exception { - Assert.assertTrue( - serviceSpec.getComponent(TaskType.TENSORBOARD.getComponentName()) - != null); + ServiceWrapper serviceWrapper, Resource resource) throws Exception { + Service serviceSpec = serviceWrapper.getService(); + assertNotNull( + serviceSpec.getComponent(TaskType.TENSORBOARD.getComponentName())); Component tensorboardComp = serviceSpec.getComponent( TaskType.TENSORBOARD.getComponentName()); - Assert.assertEquals(1, tensorboardComp.getNumberOfContainers().intValue()); - Assert.assertEquals(resource.getMemorySize(), + assertEquals(1, tensorboardComp.getNumberOfContainers().intValue()); + assertEquals(resource.getMemorySize(), tensorboardComp.getResource().calcMemoryMB()); - Assert.assertEquals(resource.getVirtualCores(), + assertEquals(resource.getVirtualCores(), tensorboardComp.getResource().getCpus().intValue()); - Assert.assertEquals("./run-TENSORBOARD.sh", + assertEquals("./run-TENSORBOARD.sh", tensorboardComp.getLaunchCommand()); // Check docker image if 
(runJobCli.getRunJobParameters().getTensorboardDockerImage() != null) { - Assert.assertEquals( + assertEquals( runJobCli.getRunJobParameters().getTensorboardDockerImage(), tensorboardComp.getArtifact().getId()); } else { - Assert.assertNull(tensorboardComp.getArtifact()); + assertNull(tensorboardComp.getArtifact()); } - YarnServiceJobSubmitter yarnServiceJobSubmitter = - (YarnServiceJobSubmitter) runJobCli.getJobSubmitter(); - String expectedLaunchScript = "#!/bin/bash\n" + "echo \"CLASSPATH:$CLASSPATH\"\n" + "echo \"HADOOP_CONF_DIR:$HADOOP_CONF_DIR\"\n" - + "echo \"HADOOP_TOKEN_FILE_LOCATION:$HADOOP_TOKEN_FILE_LOCATION\"\n" + + "echo \"HADOOP_TOKEN_FILE_LOCATION:" + + "$HADOOP_TOKEN_FILE_LOCATION\"\n" + "echo \"JAVA_HOME:$JAVA_HOME\"\n" + "echo \"LD_LIBRARY_PATH:$LD_LIBRARY_PATH\"\n" + "echo \"HADOOP_HDFS_HOME:$HADOOP_HDFS_HOME\"\n" + "export LC_ALL=C && tensorboard --logdir=" + runJobCli .getRunJobParameters().getCheckpointPath() + "\n"; - verifyLaunchScriptForComponet(yarnServiceJobSubmitter, serviceSpec, + verifyLaunchScriptForComponent(serviceWrapper, TaskType.TENSORBOARD, expectedLaunchScript); } - private void verifyLaunchScriptForComponet( - YarnServiceJobSubmitter yarnServiceJobSubmitter, Service serviceSpec, + private void verifyLaunchScriptForComponent(ServiceWrapper serviceWrapper, TaskType taskType, String expectedLaunchScriptContent) throws Exception { - Map componentToLocalLaunchScriptMap = - yarnServiceJobSubmitter.getComponentToLocalLaunchScriptPath(); - String path = componentToLocalLaunchScriptMap.get( - taskType.getComponentName()); + String path = serviceWrapper + .getLocalLaunchCommandPathForComponent(taskType.getComponentName()); byte[] encoded = Files.readAllBytes(Paths.get(path)); String scriptContent = new String(encoded, Charset.defaultCharset()); - Assert.assertEquals(expectedLaunchScriptContent, scriptContent); + assertEquals(expectedLaunchScriptContent, scriptContent); } @Test @@ -361,21 +400,28 @@ public void 
testBasicRunJobForSingleNodeTrainingWithTensorboard() MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "1", "--worker_launch_cmd", "python run-job.py", - "--worker_resources", "memory=2G,vcores=2", "--tensorboard", - "--verbose"}); - Service serviceSpec = getServiceSpecFromJobSubmitter( + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(1) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES) + .withTensorboard() + .withVerbose() + .build(); + runJobCli.run(params); + ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter( runJobCli.getJobSubmitter()); + Service serviceSpec = serviceWrapper.getService(); - Assert.assertEquals(2, serviceSpec.getComponents().size()); + assertEquals(2, serviceSpec.getComponents().size()); commonTestSingleNodeTraining(serviceSpec); - verifyTensorboardComponent(runJobCli, serviceSpec, + verifyTensorboardComponent(runJobCli, serviceWrapper, Resources.createResource(4096, 1)); } @@ -385,20 +431,27 @@ public void testBasicRunJobForSingleNodeTrainingWithGeneratedCheckpoint() MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--num_workers", "1", - 
"--worker_launch_cmd", "python run-job.py", "--worker_resources", - "memory=2G,vcores=2", "--tensorboard", "--verbose"}); - Service serviceSpec = getServiceSpecFromJobSubmitter( + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withNumberOfWorkers(1) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES) + .withTensorboard() + .withVerbose() + .build(); + runJobCli.run(params); + ServiceWrapper serviceWrapper = getServiceWrapperFromJobSubmitter( runJobCli.getJobSubmitter()); + Service serviceSpec = serviceWrapper.getService(); - Assert.assertEquals(2, serviceSpec.getComponents().size()); + assertEquals(2, serviceSpec.getComponents().size()); commonTestSingleNodeTraining(serviceSpec); - verifyTensorboardComponent(runJobCli, serviceSpec, + verifyTensorboardComponent(runJobCli, serviceWrapper, Resources.createResource(4096, 1)); } @@ -407,20 +460,26 @@ public void testParameterStorageForTrainingJob() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "1", "--worker_launch_cmd", "python run-job.py", - "--worker_resources", "memory=2G,vcores=2", "--tensorboard", "true", - "--verbose"}); + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(1) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_TENSORBOARD_RESOURCES) + 
.withTensorboard() + .withVerbose() + .build(); + runJobCli.run(params); SubmarineStorage storage = mockClientContext.getRuntimeFactory().getSubmarineStorage(); - Map jobInfo = storage.getJobInfoByName("my-job"); - Assert.assertTrue(jobInfo.size() > 0); - Assert.assertEquals(jobInfo.get(StorageKeyConstants.INPUT_PATH), - "s3://input"); + Map jobInfo = storage.getJobInfoByName(DEFAULT_JOB_NAME); + assertTrue(jobInfo.size() > 0); + assertEquals(jobInfo.get(StorageKeyConstants.INPUT_PATH), + DEFAULT_INPUT_PATH); } @Test @@ -428,21 +487,29 @@ public void testAddQuicklinksWithoutTensorboard() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", "--quicklink", - "AAA=http://master-0:8321", "--quicklink", - "BBB=http://worker-0:1234"}); - Service serviceSpec = getServiceSpecFromJobSubmitter( + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(3) + .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_WORKER_RESOURCES) + .withNumberOfPs(2) + .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE) + .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD) + 
.withPsResources(DEFAULT_PS_RESOURCES) + .withQuickLink("AAA=http://master-0:8321") + .withQuickLink("BBB=http://worker-0:1234") + .withVerbose() + .build(); + runJobCli.run(params); + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(3, serviceSpec.getComponents().size()); + assertEquals(3, serviceSpec.getComponents().size()); commonVerifyDistributedTrainingSpec(serviceSpec); @@ -456,712 +523,41 @@ public void testAddQuicklinksWithTensorboard() throws Exception { MockClientContext mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); + assertFalse(SubmarineLogs.isVerbose()); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", "--quicklink", - "AAA=http://master-0:8321", "--quicklink", - "BBB=http://worker-0:1234", "--tensorboard"}); - Service serviceSpec = getServiceSpecFromJobSubmitter( + String[] params = ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(3) + .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_WORKER_RESOURCES) + .withNumberOfPs(2) + .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE) + .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD) + .withPsResources(DEFAULT_PS_RESOURCES) + .withQuickLink("AAA=http://master-0:8321") + 
.withQuickLink("BBB=http://worker-0:1234") + .withTensorboard() + .withVerbose() + .build(); + + runJobCli.run(params); + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( runJobCli.getJobSubmitter()); - Assert.assertEquals(4, serviceSpec.getComponents().size()); + assertEquals(4, serviceSpec.getComponents().size()); commonVerifyDistributedTrainingSpec(serviceSpec); verifyQuicklink(serviceSpec, ImmutableMap .of("AAA", "http://master-0.my-job.username.null:8321", "BBB", "http://worker-0.my-job.username.null:1234", - YarnServiceJobSubmitter.TENSORBOARD_QUICKLINK_LABEL, + TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL, "http://tensorboard-0.my-job.username.null:6006")); } - /** - * Basic test. - * In one hand, create local temp file/dir for hdfs URI in - * local staging dir. - * In the other hand, use MockRemoteDirectoryManager mock - * implementation when check FileStatus or exists of HDFS file/dir - * --localization hdfs:///user/yarn/script1.py:. - * --localization /temp/script2.py:./ - * --localization /temp/script2.py:/opt/script.py - */ - @Test - public void testRunJobWithBasicLocalization() throws Exception { - String remoteUrl = "hdfs:///user/yarn/script1.py"; - String containerLocal1 = "."; - String localUrl = "/temp/script2.py"; - String containerLocal2 = "./"; - String containerLocal3 = "/opt/script.py"; - String fakeLocalDir = System.getProperty("java.io.tmpdir"); - // create local file, we need to put it under local temp dir - File localFile1 = new File(fakeLocalDir, - new Path(localUrl).getName()); - localFile1.createNewFile(); - - - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); - - RemoteDirectoryManager spyRdm = - spy(mockClientContext.getRemoteDirectoryManager()); - mockClientContext.setRemoteDirectoryMgr(spyRdm); - - // create remote file in local staging dir to simulate HDFS - Path 
stagingDir = mockClientContext.getRemoteDirectoryManager() - .getJobStagingArea("my-job", true); - File remoteFile1 = new File(stagingDir.toUri().getPath() - + "/" + new Path(remoteUrl).getName()); - remoteFile1.createNewFile(); - - Assert.assertTrue(localFile1.exists()); - Assert.assertTrue(remoteFile1.exists()); - - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - remoteUrl + ":" + containerLocal1, - "--localization", - localFile1.getAbsolutePath() + ":" + containerLocal2, - "--localization", - localFile1.getAbsolutePath() + ":" + containerLocal3}); - Service serviceSpec = getServiceSpecFromJobSubmitter( - runJobCli.getJobSubmitter()); - Assert.assertEquals(3, serviceSpec.getComponents().size()); - - // No remote dir and hdfs file exists. 
Ensure download 0 times - verify(spyRdm, times(0)).copyRemoteToLocal( - anyString(), anyString()); - // Ensure local original files are not deleted - Assert.assertTrue(localFile1.exists()); - - List files = serviceSpec.getConfiguration().getFiles(); - Assert.assertEquals(3, files.size()); - ConfigFile file = files.get(0); - Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType()); - String expectedSrcLocalization = remoteUrl; - Assert.assertEquals(expectedSrcLocalization, - file.getSrcFile()); - String expectedDstFileName = new Path(remoteUrl).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = files.get(1); - Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(localUrl).getName(); - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - expectedDstFileName = new Path(localUrl).getName(); - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - - file = files.get(2); - Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(localUrl).getName(); - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - expectedDstFileName = new Path(localUrl).getName(); - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - - // Ensure env value is correct - String env = serviceSpec.getConfiguration().getEnv() - .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); - String expectedMounts = new Path(containerLocal3).getName() - + ":" + containerLocal3 + ":rw"; - Assert.assertTrue(env.contains(expectedMounts)); - - remoteFile1.delete(); - localFile1.delete(); - } - - /** - * Non HDFS remote URI test. - * --localization https://a/b/1.patch:. 
- * --localization s3a://a/dir:/opt/mys3dir - */ - @Test - public void testRunJobWithNonHDFSRemoteLocalization() throws Exception { - String remoteUri1 = "https://a/b/1.patch"; - String containerLocal1 = "."; - String remoteUri2 = "s3a://a/s3dir"; - String containerLocal2 = "/opt/mys3dir"; - - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); - - RemoteDirectoryManager spyRdm = - spy(mockClientContext.getRemoteDirectoryManager()); - mockClientContext.setRemoteDirectoryMgr(spyRdm); - - // create remote file in local staging dir to simulate HDFS - Path stagingDir = mockClientContext.getRemoteDirectoryManager() - .getJobStagingArea("my-job", true); - File remoteFile1 = new File(stagingDir.toUri().getPath() - + "/" + new Path(remoteUri1).getName()); - remoteFile1.createNewFile(); - - File remoteDir1 = new File(stagingDir.toUri().getPath() - + "/" + new Path(remoteUri2).getName()); - remoteDir1.mkdir(); - File remoteDir1File1 = new File(remoteDir1, "afile"); - remoteDir1File1.createNewFile(); - - Assert.assertTrue(remoteFile1.exists()); - Assert.assertTrue(remoteDir1.exists()); - Assert.assertTrue(remoteDir1File1.exists()); - - String suffix1 = "_" + remoteDir1.lastModified() - + "-" + mockClientContext.getRemoteDirectoryManager() - .getRemoteFileSize(remoteUri2); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - remoteUri1 + ":" + containerLocal1, - "--localization", - remoteUri2 + ":" + 
containerLocal2}); - Service serviceSpec = getServiceSpecFromJobSubmitter( - runJobCli.getJobSubmitter()); - Assert.assertEquals(3, serviceSpec.getComponents().size()); - - // Ensure download remote dir 2 times - verify(spyRdm, times(2)).copyRemoteToLocal( - anyString(), anyString()); - - // Ensure downloaded temp files are deleted - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUri1).getName()).exists()); - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUri2).getName()).exists()); - - // Ensure zip file are deleted - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUri2).getName() - + "_" + suffix1 + ".zip").exists()); - - List files = serviceSpec.getConfiguration().getFiles(); - Assert.assertEquals(2, files.size()); - ConfigFile file = files.get(0); - Assert.assertEquals(ConfigFile.TypeEnum.STATIC, file.getType()); - String expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(remoteUri1).getName(); - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - String expectedDstFileName = new Path(remoteUri1).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = files.get(1); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(remoteUri2).getName() + suffix1 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - - expectedDstFileName = new Path(containerLocal2).getName(); - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - - // Ensure env value is correct - String env = serviceSpec.getConfiguration().getEnv() - .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); - String expectedMounts = new Path(remoteUri2).getName() - + ":" + containerLocal2 + ":rw"; - 
Assert.assertTrue(env.contains(expectedMounts)); - - remoteDir1File1.delete(); - remoteFile1.delete(); - remoteDir1.delete(); - } - - /** - * Test HDFS dir localization. - * --localization hdfs:///user/yarn/mydir:./mydir1 - * --localization hdfs:///user/yarn/mydir2:/opt/dir2:rw - * --localization hdfs:///user/yarn/mydir:. - * --localization hdfs:///user/yarn/mydir2:./ - */ - @Test - public void testRunJobWithHdfsDirLocalization() throws Exception { - String remoteUrl = "hdfs:///user/yarn/mydir"; - String containerPath = "./mydir1"; - String remoteUrl2 = "hdfs:///user/yarn/mydir2"; - String containPath2 = "/opt/dir2"; - String containerPath3 = "."; - String containerPath4 = "./"; - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); - - RemoteDirectoryManager spyRdm = - spy(mockClientContext.getRemoteDirectoryManager()); - mockClientContext.setRemoteDirectoryMgr(spyRdm); - // create remote file in local staging dir to simulate HDFS - Path stagingDir = mockClientContext.getRemoteDirectoryManager() - .getJobStagingArea("my-job", true); - File remoteDir1 = new File(stagingDir.toUri().getPath().toString() - + "/" + new Path(remoteUrl).getName()); - remoteDir1.mkdir(); - File remoteFile1 = new File(remoteDir1.getAbsolutePath() + "/1.py"); - File remoteFile2 = new File(remoteDir1.getAbsolutePath() + "/2.py"); - remoteFile1.createNewFile(); - remoteFile2.createNewFile(); - - File remoteDir2 = new File(stagingDir.toUri().getPath().toString() - + "/" + new Path(remoteUrl2).getName()); - remoteDir2.mkdir(); - File remoteFile3 = new File(remoteDir1.getAbsolutePath() + "/3.py"); - File remoteFile4 = new File(remoteDir1.getAbsolutePath() + "/4.py"); - remoteFile3.createNewFile(); - remoteFile4.createNewFile(); - - Assert.assertTrue(remoteDir1.exists()); - Assert.assertTrue(remoteDir2.exists()); - - String suffix1 = "_" + 
remoteDir1.lastModified() - + "-" + mockClientContext.getRemoteDirectoryManager() - .getRemoteFileSize(remoteUrl); - String suffix2 = "_" + remoteDir2.lastModified() - + "-" + mockClientContext.getRemoteDirectoryManager() - .getRemoteFileSize(remoteUrl2); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - remoteUrl + ":" + containerPath, - "--localization", - remoteUrl2 + ":" + containPath2 + ":rw", - "--localization", - remoteUrl + ":" + containerPath3, - "--localization", - remoteUrl2 + ":" + containerPath4}); - Service serviceSpec = getServiceSpecFromJobSubmitter( - runJobCli.getJobSubmitter()); - Assert.assertEquals(3, serviceSpec.getComponents().size()); - - // Ensure download remote dir 4 times - verify(spyRdm, times(4)).copyRemoteToLocal( - anyString(), anyString()); - - // Ensure downloaded temp files are deleted - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUrl).getName()).exists()); - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUrl2).getName()).exists()); - // Ensure zip file are deleted - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUrl).getName() - + suffix1 + ".zip").exists()); - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(remoteUrl2).getName() - + suffix2 + ".zip").exists()); - - // Ensure files will be localized - List files = serviceSpec.getConfiguration().getFiles(); - Assert.assertEquals(4, files.size()); - ConfigFile file = files.get(0); - // 
The hdfs dir should be download and compress and let YARN to uncompress - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - String expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(remoteUrl).getName() + suffix1 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - - // Relative path in container, but not "." or "./". Use its own name - String expectedDstFileName = new Path(containerPath).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = files.get(1); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(remoteUrl2).getName() + suffix2 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - - expectedDstFileName = new Path(containPath2).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = files.get(2); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(remoteUrl).getName() + suffix1 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - // Relative path in container ".", use remote path name - expectedDstFileName = new Path(remoteUrl).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = files.get(3); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(remoteUrl2).getName() + suffix2 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - // Relative path in container "./", use remote path name - expectedDstFileName = new Path(remoteUrl2).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - // Ensure mounts env value 
is correct. Add one mount string - String env = serviceSpec.getConfiguration().getEnv() - .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); - - String expectedMounts = - new Path(containPath2).getName() + ":" + containPath2 + ":rw"; - Assert.assertTrue(env.contains(expectedMounts)); - - remoteFile1.delete(); - remoteFile2.delete(); - remoteFile3.delete(); - remoteFile4.delete(); - remoteDir1.delete(); - remoteDir2.delete(); - } - - /** - * Test if file/dir to be localized whose size exceeds limit. - * Max 10MB in configuration, mock remote will - * always return file size 100MB. - * This configuration will fail the job which has remoteUri - * But don't impact local dir/file - * - * --localization https://a/b/1.patch:. - * --localization s3a://a/dir:/opt/mys3dir - * --localization /temp/script2.py:./ - */ - @Test - public void testRunJobRemoteUriExceedLocalizationSize() throws Exception { - String remoteUri1 = "https://a/b/1.patch"; - String containerLocal1 = "."; - String remoteUri2 = "s3a://a/s3dir"; - String containerLocal2 = "/opt/mys3dir"; - String localUri1 = "/temp/script2"; - String containerLocal3 = "./"; - - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - SubmarineConfiguration submarineConf = new SubmarineConfiguration(); - RemoteDirectoryManager spyRdm = - spy(mockClientContext.getRemoteDirectoryManager()); - mockClientContext.setRemoteDirectoryMgr(spyRdm); - /** - * Max 10MB, mock remote will always return file size 100MB. 
- * */ - submarineConf.set( - SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB, - "10"); - mockClientContext.setSubmarineConfig(submarineConf); - - RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); - - // create remote file in local staging dir to simulate - Path stagingDir = mockClientContext.getRemoteDirectoryManager() - .getJobStagingArea("my-job", true); - File remoteFile1 = new File(stagingDir.toUri().getPath() - + "/" + new Path(remoteUri1).getName()); - remoteFile1.createNewFile(); - File remoteDir1 = new File(stagingDir.toUri().getPath() - + "/" + new Path(remoteUri2).getName()); - remoteDir1.mkdir(); - - File remoteDir1File1 = new File(remoteDir1, "afile"); - remoteDir1File1.createNewFile(); - - String fakeLocalDir = System.getProperty("java.io.tmpdir"); - // create local file, we need to put it under local temp dir - File localFile1 = new File(fakeLocalDir, - new Path(localUri1).getName()); - localFile1.createNewFile(); - - Assert.assertTrue(remoteFile1.exists()); - Assert.assertTrue(remoteDir1.exists()); - Assert.assertTrue(remoteDir1File1.exists()); - - String suffix1 = "_" + remoteDir1.lastModified() - + "-" + remoteDir1.length(); - try { - runJobCli = new RunJobCli(mockClientContext); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", - "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - remoteUri1 + ":" + containerLocal1}); - } catch (IOException e) { - // Shouldn't have exception because it's within file size limit - Assert.assertFalse(true); - } - // we should download because fail fast - verify(spyRdm, 
times(1)).copyRemoteToLocal( - anyString(), anyString()); - try { - // reset - reset(spyRdm); - runJobCli = new RunJobCli(mockClientContext); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", - "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - remoteUri1 + ":" + containerLocal1, - "--localization", - remoteUri2 + ":" + containerLocal2, - "--localization", - localFile1.getAbsolutePath() + ":" + containerLocal3}); - } catch (IOException e) { - Assert.assertTrue(e.getMessage() - .contains("104857600 exceeds configured max size:10485760")); - // we shouldn't do any download because fail fast - verify(spyRdm, times(0)).copyRemoteToLocal( - anyString(), anyString()); - } - - try { - runJobCli = new RunJobCli(mockClientContext); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", - "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - localFile1.getAbsolutePath() + ":" + containerLocal3}); - } catch (IOException e) { - Assert.assertTrue(e.getMessage() - .contains("104857600 exceeds configured max size:10485760")); - // we shouldn't do any download because fail fast - verify(spyRdm, times(0)).copyRemoteToLocal( - anyString(), anyString()); - } - - localFile1.delete(); - remoteDir1File1.delete(); - 
remoteFile1.delete(); - remoteDir1.delete(); - } - - /** - * Test remote Uri doesn't exist. - * */ - @Test - public void testRunJobWithNonExistRemoteUri() throws Exception { - String remoteUri1 = "hdfs:///a/b/1.patch"; - String containerLocal1 = "."; - String localUri1 = "/a/b/c"; - String containerLocal2 = "./"; - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - - RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); - - try { - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", - "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - remoteUri1 + ":" + containerLocal1}); - } catch (IOException e) { - Assert.assertTrue(e.getMessage() - .contains("doesn't exists")); - } - - try { - runJobCli = new RunJobCli(mockClientContext); - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", - "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - localUri1 + ":" + containerLocal2}); - } catch (IOException e) { - Assert.assertTrue(e.getMessage() - .contains("doesn't exists")); - } - } - - /** - * Test local dir - * --localization /user/yarn/mydir:./mydir1 - * --localization /user/yarn/mydir2:/opt/dir2:rw - * --localization /user/yarn/mydir2:. 
- */ - @Test - public void testRunJobWithLocalDirLocalization() throws Exception { - String fakeLocalDir = System.getProperty("java.io.tmpdir"); - String localUrl = "/user/yarn/mydir"; - String containerPath = "./mydir1"; - String localUrl2 = "/user/yarn/mydir2"; - String containPath2 = "/opt/dir2"; - String containerPath3 = "."; - - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - RunJobCli runJobCli = new RunJobCli(mockClientContext); - Assert.assertFalse(SubmarineLogs.isVerbose()); - - RemoteDirectoryManager spyRdm = - spy(mockClientContext.getRemoteDirectoryManager()); - mockClientContext.setRemoteDirectoryMgr(spyRdm); - // create local file - File localDir1 = new File(fakeLocalDir, - localUrl); - localDir1.mkdirs(); - File temp1 = new File(localDir1.getAbsolutePath() + "/1.py"); - File temp2 = new File(localDir1.getAbsolutePath() + "/2.py"); - temp1.createNewFile(); - temp2.createNewFile(); - - File localDir2 = new File(fakeLocalDir, - localUrl2); - localDir2.mkdirs(); - File temp3 = new File(localDir1.getAbsolutePath() + "/3.py"); - File temp4 = new File(localDir1.getAbsolutePath() + "/4.py"); - temp3.createNewFile(); - temp4.createNewFile(); - - Assert.assertTrue(localDir1.exists()); - Assert.assertTrue(localDir2.exists()); - - String suffix1 = "_" + localDir1.lastModified() - + "-" + localDir1.length(); - String suffix2 = "_" + localDir2.lastModified() - + "-" + localDir2.length(); - - runJobCli.run( - new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0", - "--input_path", "s3://input", "--checkpoint_path", "s3://output", - "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", - "python run-job.py", "--worker_resources", "memory=2048M,vcores=2", - "--ps_resources", "memory=4096M,vcores=4", "--ps_docker_image", - "ps.image", "--worker_docker_image", "worker.image", - "--ps_launch_cmd", "python run-ps.py", "--verbose", - "--localization", - fakeLocalDir + localUrl + ":" + containerPath, - 
"--localization", - fakeLocalDir + localUrl2 + ":" + containPath2 + ":rw", - "--localization", - fakeLocalDir + localUrl2 + ":" + containerPath3}); - - Service serviceSpec = getServiceSpecFromJobSubmitter( - runJobCli.getJobSubmitter()); - Assert.assertEquals(3, serviceSpec.getComponents().size()); - - // we shouldn't do any download - verify(spyRdm, times(0)).copyRemoteToLocal( - anyString(), anyString()); - - // Ensure local original files are not deleted - Assert.assertTrue(localDir1.exists()); - Assert.assertTrue(localDir2.exists()); - - // Ensure zip file are deleted - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(localUrl).getName() - + suffix1 + ".zip").exists()); - Assert.assertFalse(new File(System.getProperty("java.io.tmpdir") - + "/" + new Path(localUrl2).getName() - + suffix2 + ".zip").exists()); - - // Ensure dirs will be zipped and localized - List files = serviceSpec.getConfiguration().getFiles(); - Assert.assertEquals(3, files.size()); - ConfigFile file = files.get(0); - Path stagingDir = mockClientContext.getRemoteDirectoryManager() - .getJobStagingArea("my-job", true); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - String expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(localUrl).getName() + suffix1 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - String expectedDstFileName = new Path(containerPath).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = files.get(1); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(localUrl2).getName() + suffix2 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - expectedDstFileName = new Path(containPath2).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - file = 
files.get(2); - Assert.assertEquals(ConfigFile.TypeEnum.ARCHIVE, file.getType()); - expectedSrcLocalization = stagingDir.toUri().getPath() - + "/" + new Path(localUrl2).getName() + suffix2 + ".zip"; - Assert.assertEquals(expectedSrcLocalization, - new Path(file.getSrcFile()).toUri().getPath()); - expectedDstFileName = new Path(localUrl2).getName(); - Assert.assertEquals(expectedDstFileName, file.getDestFile()); - - // Ensure mounts env value is correct - String env = serviceSpec.getConfiguration().getEnv() - .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); - String expectedMounts = new Path(containPath2).getName() - + ":" + containPath2 + ":rw"; - - Assert.assertTrue(env.contains(expectedMounts)); - - temp1.delete(); - temp2.delete(); - temp3.delete(); - temp4.delete(); - localDir2.delete(); - localDir1.delete(); - } - /** * Test zip function. * A dir "/user/yarn/mydir" has two files and one subdir @@ -1169,52 +565,32 @@ public void testRunJobWithLocalDirLocalization() throws Exception { @Test public void testYarnServiceSubmitterZipFunction() throws Exception { - MockClientContext mockClientContext = - YarnServiceCliTestUtils.getMockClientContext(); - RunJobCli runJobCli = new RunJobCli(mockClientContext); - YarnServiceJobSubmitter submitter = - (YarnServiceJobSubmitter)mockClientContext - .getRuntimeFactory().getJobSubmitterInstance(); - String fakeLocalDir = System.getProperty("java.io.tmpdir"); String localUrl = "/user/yarn/mydir"; String localSubDirName = "subdir1"; + // create local file - File localDir1 = new File(fakeLocalDir, - localUrl); - localDir1.mkdirs(); - File temp1 = new File(localDir1.getAbsolutePath() + "/1.py"); - File temp2 = new File(localDir1.getAbsolutePath() + "/2.py"); - temp1.createNewFile(); - temp2.createNewFile(); + File localDir1 = testCommons.getFileUtils().createDirInTempDir(localUrl); + testCommons.getFileUtils().createFileInDir(localDir1, "1.py"); + testCommons.getFileUtils().createFileInDir(localDir1, "2.py"); + File localSubDir = + 
testCommons.getFileUtils().createDirectory(localDir1, localSubDirName); + testCommons.getFileUtils().createFileInDir(localSubDir, "3.py"); - File localSubDir = new File(localDir1.getAbsolutePath(), localSubDirName); - localSubDir.mkdir(); - File temp3 = new File(localSubDir.getAbsolutePath(), "3.py"); - temp3.createNewFile(); - - - String zipFilePath = submitter.zipDir(localDir1.getAbsolutePath(), - fakeLocalDir + "/user/yarn/mydir.zip"); + String tempDir = localDir1.getParent(); + String zipFilePath = ZipUtilities.zipDir(localDir1.getAbsolutePath(), + new File(tempDir, "mydir.zip").getAbsolutePath()); File zipFile = new File(zipFilePath); - File unzipTargetDir = new File(fakeLocalDir, "unzipDir"); + File unzipTargetDir = new File(tempDir, "unzipDir"); FileUtil.unZip(zipFile, unzipTargetDir); - Assert.assertTrue( - new File(fakeLocalDir + "/unzipDir/1.py").exists()); - Assert.assertTrue( - new File(fakeLocalDir + "/unzipDir/2.py").exists()); - Assert.assertTrue( - new File(fakeLocalDir + "/unzipDir/subdir1").exists()); - Assert.assertTrue( - new File(fakeLocalDir + "/unzipDir/subdir1/3.py").exists()); - - zipFile.delete(); - unzipTargetDir.delete(); - temp1.delete(); - temp2.delete(); - temp3.delete(); - localSubDir.delete(); - localDir1.delete(); + assertTrue( + new File(tempDir + "/unzipDir/1.py").exists()); + assertTrue( + new File(tempDir + "/unzipDir/2.py").exists()); + assertTrue( + new File(tempDir + "/unzipDir/subdir1").exists()); + assertTrue( + new File(tempDir + "/unzipDir/subdir1/3.py").exists()); } } diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliCommons.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliCommons.java new file mode 100644 index 00000000000..94e2c378672 --- /dev/null +++ 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliCommons.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.client.cli.yarnservice; + +import org.apache.hadoop.yarn.client.api.AppAdminClient; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.service.api.records.Service; +import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; +import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceJobSubmitter; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils; + +import java.io.IOException; + +import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Common operations shared with test classes using Run job-related actions. 
+ */ +public class TestYarnServiceRunJobCliCommons { + static final String DEFAULT_JOB_NAME = "my-job"; + static final String DEFAULT_DOCKER_IMAGE = "tf-docker:1.1.0"; + static final String DEFAULT_INPUT_PATH = "s3://input"; + static final String DEFAULT_CHECKPOINT_PATH = "s3://output"; + static final String DEFAULT_WORKER_DOCKER_IMAGE = "worker.image"; + static final String DEFAULT_PS_DOCKER_IMAGE = "ps.image"; + static final String DEFAULT_WORKER_LAUNCH_CMD = "python run-job.py"; + static final String DEFAULT_PS_LAUNCH_CMD = "python run-ps.py"; + static final String DEFAULT_TENSORBOARD_RESOURCES = "memory=2G,vcores=2"; + static final String DEFAULT_WORKER_RESOURCES = "memory=2048M,vcores=2"; + static final String DEFAULT_PS_RESOURCES = "memory=4096M,vcores=4"; + static final String DEFAULT_TENSORBOARD_DOCKER_IMAGE = "tb_docker_image:001"; + + private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests(); + + void setup() throws IOException, YarnException { + SubmarineLogs.verboseOff(); + AppAdminClient serviceClient = mock(AppAdminClient.class); + when(serviceClient.actionLaunch(any(String.class), any(String.class), + any(Long.class), any(String.class))).thenReturn(EXIT_SUCCESS); + when(serviceClient.getStatusString(any(String.class))).thenReturn( + "{\"id\": \"application_1234_1\"}"); + YarnServiceUtils.setStubServiceClient(serviceClient); + + fileUtils.setup(); + } + + void teardown() throws IOException { + fileUtils.teardown(); + } + + FileUtilitiesForTests getFileUtils() { + return fileUtils; + } + + Service getServiceSpecFromJobSubmitter(JobSubmitter jobSubmitter) { + return ((YarnServiceJobSubmitter) jobSubmitter).getServiceWrapper() + .getService(); + } + +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliLocalization.java 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliLocalization.java new file mode 100644 index 00000000000..9bee3024616 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/client/cli/yarnservice/TestYarnServiceRunJobCliLocalization.java @@ -0,0 +1,599 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.client.cli.yarnservice; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.service.api.records.ConfigFile; +import org.apache.hadoop.yarn.service.api.records.Service; +import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli; +import org.apache.hadoop.yarn.submarine.common.MockClientContext; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineConfiguration; +import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; +import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; +import static org.apache.hadoop.yarn.submarine.client.cli.yarnservice.TestYarnServiceRunJobCliCommons.*; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Class to test YarnService localization feature with the Run job CLI action. 
+ */ +public class TestYarnServiceRunJobCliLocalization { + private static final String ZIP_EXTENSION = ".zip"; + private TestYarnServiceRunJobCliCommons testCommons = + new TestYarnServiceRunJobCliCommons(); + private MockClientContext mockClientContext; + private RemoteDirectoryManager spyRdm; + + @Before + public void before() throws IOException, YarnException { + testCommons.setup(); + mockClientContext = YarnServiceCliTestUtils.getMockClientContext(); + spyRdm = setupSpyRemoteDirManager(); + } + + @After + public void cleanup() throws IOException { + testCommons.teardown(); + } + + private ParamBuilderForTest createCommonParamsBuilder() { + return ParamBuilderForTest.create() + .withJobName(DEFAULT_JOB_NAME) + .withDockerImage(DEFAULT_DOCKER_IMAGE) + .withInputPath(DEFAULT_INPUT_PATH) + .withCheckpointPath(DEFAULT_CHECKPOINT_PATH) + .withNumberOfWorkers(3) + .withWorkerDockerImage(DEFAULT_WORKER_DOCKER_IMAGE) + .withWorkerLaunchCommand(DEFAULT_WORKER_LAUNCH_CMD) + .withWorkerResources(DEFAULT_WORKER_RESOURCES) + .withNumberOfPs(2) + .withPsDockerImage(DEFAULT_PS_DOCKER_IMAGE) + .withPsLaunchCommand(DEFAULT_PS_LAUNCH_CMD) + .withPsResources(DEFAULT_PS_RESOURCES) + .withVerbose(); + } + + private void assertFilesAreDeleted(File... 
files) { + for (File file : files) { + assertFalse("File should be deleted: " + file.getAbsolutePath(), + file.exists()); + } + } + + private RemoteDirectoryManager setupSpyRemoteDirManager() { + RemoteDirectoryManager spyRdm = + spy(mockClientContext.getRemoteDirectoryManager()); + mockClientContext.setRemoteDirectoryMgr(spyRdm); + return spyRdm; + } + + private Path getStagingDir() throws IOException { + return mockClientContext.getRemoteDirectoryManager() + .getJobStagingArea(DEFAULT_JOB_NAME, true); + } + + private RunJobCli createRunJobCliWithoutVerboseAssertion() { + return new RunJobCli(mockClientContext); + } + + private RunJobCli createRunJobCli() { + RunJobCli runJobCli = new RunJobCli(mockClientContext); + assertFalse(SubmarineLogs.isVerbose()); + return runJobCli; + } + + private String getFilePath(String localUrl, Path stagingDir) { + return stagingDir.toUri().getPath() + + "/" + new Path(localUrl).getName(); + } + + private String getFilePathWithSuffix(Path stagingDir, String localUrl, + String suffix) { + return stagingDir.toUri().getPath() + "/" + new Path(localUrl).getName() + + suffix; + } + + private void assertConfigFile(ConfigFile expected, ConfigFile actual) { + assertEquals("ConfigFile does not equal to expected!", expected, actual); + } + + private void assertNumberOfLocalizations(List files, + int expected) { + assertEquals("Number of localizations is not the expected!", expected, + files.size()); + } + + private void verifyRdmCopyToRemoteLocalCalls(int expectedCalls) + throws IOException { + verify(spyRdm, times(expectedCalls)).copyRemoteToLocal(anyString(), + anyString()); + } + + /** + * Basic test. + * In one hand, create local temp file/dir for hdfs URI in + * local staging dir. + * In the other hand, use MockRemoteDirectoryManager mock + * implementation when check FileStatus or exists of HDFS file/dir + * --localization hdfs:///user/yarn/script1.py:. 
+ * --localization /temp/script2.py:./ + * --localization /temp/script2.py:/opt/script.py + */ + @Test + public void testRunJobWithBasicLocalization() throws Exception { + String remoteUrl = "hdfs:///user/yarn/script1.py"; + String containerLocal1 = "."; + String localUrl = "/temp/script2.py"; + String containerLocal2 = "./"; + String containerLocal3 = "/opt/script.py"; + // Create local file, we need to put it under local temp dir + File localFile1 = testCommons.getFileUtils().createFileInTempDir(localUrl); + + // create remote file in local staging dir to simulate HDFS + Path stagingDir = getStagingDir(); + testCommons.getFileUtils().createFileInDir(stagingDir, remoteUrl); + + String[] params = createCommonParamsBuilder() + .withLocalization(remoteUrl, containerLocal1) + .withLocalization(localFile1.getAbsolutePath(), containerLocal2) + .withLocalization(localFile1.getAbsolutePath(), containerLocal3) + .build(); + RunJobCli runJobCli = createRunJobCli(); + runJobCli.run(params); + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( + runJobCli.getJobSubmitter()); + assertNumberOfServiceComponents(serviceSpec, 3); + + // No remote dir and HDFS file exists. + // Ensure download never happened. 
+ verifyRdmCopyToRemoteLocalCalls(0); + // Ensure local original files are not deleted + assertTrue(localFile1.exists()); + + List files = serviceSpec.getConfiguration().getFiles(); + assertNumberOfLocalizations(files, 3); + + ConfigFile expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC); + expectedConfigFile.setSrcFile(remoteUrl); + expectedConfigFile.setDestFile(new Path(remoteUrl).getName()); + assertConfigFile(expectedConfigFile, files.get(0)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC); + expectedConfigFile.setSrcFile(getFilePath(localUrl, stagingDir)); + expectedConfigFile.setDestFile(new Path(localUrl).getName()); + assertConfigFile(expectedConfigFile, files.get(1)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC); + expectedConfigFile.setSrcFile(getFilePath(localUrl, stagingDir)); + expectedConfigFile.setDestFile(new Path(containerLocal3).getName()); + assertConfigFile(expectedConfigFile, files.get(2)); + + // Ensure env value is correct + String env = serviceSpec.getConfiguration().getEnv() + .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); + String expectedMounts = new Path(containerLocal3).getName() + + ":" + containerLocal3 + ":rw"; + assertTrue(env.contains(expectedMounts)); + } + + private void assertNumberOfServiceComponents(Service serviceSpec, + int expected) { + assertEquals(expected, serviceSpec.getComponents().size()); + } + + /** + * Non HDFS remote URI test. + * --localization https://a/b/1.patch:. 
+ * --localization s3a://a/dir:/opt/mys3dir + */ + @Test + public void testRunJobWithNonHDFSRemoteLocalization() throws Exception { + String remoteUri1 = "https://a/b/1.patch"; + String containerLocal1 = "."; + String remoteUri2 = "s3a://a/s3dir"; + String containerLocal2 = "/opt/mys3dir"; + + // create remote file in local staging dir to simulate HDFS + Path stagingDir = getStagingDir(); + testCommons.getFileUtils().createFileInDir(stagingDir, remoteUri1); + File remoteDir1 = + testCommons.getFileUtils().createDirectory(stagingDir, remoteUri2); + testCommons.getFileUtils().createFileInDir(remoteDir1, "afile"); + + String suffix1 = "_" + remoteDir1.lastModified() + + "-" + mockClientContext.getRemoteDirectoryManager() + .getRemoteFileSize(remoteUri2); + + String[] params = createCommonParamsBuilder() + .withLocalization(remoteUri1, containerLocal1) + .withLocalization(remoteUri2, containerLocal2) + .build(); + RunJobCli runJobCli = createRunJobCli(); + runJobCli.run(params); + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( + runJobCli.getJobSubmitter()); + assertNumberOfServiceComponents(serviceSpec, 3); + + // Ensure download remote dir 2 times + verifyRdmCopyToRemoteLocalCalls(2); + + // Ensure downloaded temp files are deleted + assertFilesAreDeleted( + testCommons.getFileUtils().getTempFileWithName(remoteUri1), + testCommons.getFileUtils().getTempFileWithName(remoteUri2)); + + // Ensure zip file are deleted + assertFilesAreDeleted( + testCommons.getFileUtils() + .getTempFileWithName(remoteUri2 + "_" + suffix1 + ZIP_EXTENSION)); + + List files = serviceSpec.getConfiguration().getFiles(); + assertNumberOfLocalizations(files, 2); + + ConfigFile expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.STATIC); + expectedConfigFile.setSrcFile(getFilePath(remoteUri1, stagingDir)); + expectedConfigFile.setDestFile(new Path(remoteUri1).getName()); + assertConfigFile(expectedConfigFile, files.get(0)); + + 
expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, remoteUri2, suffix1 + ZIP_EXTENSION)); + expectedConfigFile.setDestFile(new Path(containerLocal2).getName()); + assertConfigFile(expectedConfigFile, files.get(1)); + + // Ensure env value is correct + String env = serviceSpec.getConfiguration().getEnv() + .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); + String expectedMounts = new Path(remoteUri2).getName() + + ":" + containerLocal2 + ":rw"; + assertTrue(env.contains(expectedMounts)); + } + + /** + * Test HDFS dir localization. + * --localization hdfs:///user/yarn/mydir:./mydir1 + * --localization hdfs:///user/yarn/mydir2:/opt/dir2:rw + * --localization hdfs:///user/yarn/mydir:. + * --localization hdfs:///user/yarn/mydir2:./ + */ + @Test + public void testRunJobWithHdfsDirLocalization() throws Exception { + String remoteUrl = "hdfs:///user/yarn/mydir"; + String containerPath = "./mydir1"; + String remoteUrl2 = "hdfs:///user/yarn/mydir2"; + String containerPath2 = "/opt/dir2"; + String containerPath3 = "."; + String containerPath4 = "./"; + + // create remote file in local staging dir to simulate HDFS + Path stagingDir = getStagingDir(); + File remoteDir1 = + testCommons.getFileUtils().createDirectory(stagingDir, remoteUrl); + testCommons.getFileUtils().createFileInDir(remoteDir1, "1.py"); + testCommons.getFileUtils().createFileInDir(remoteDir1, "2.py"); + + File remoteDir2 = + testCommons.getFileUtils().createDirectory(stagingDir, remoteUrl2); + testCommons.getFileUtils().createFileInDir(remoteDir2, "3.py"); + testCommons.getFileUtils().createFileInDir(remoteDir2, "4.py"); + + String suffix1 = "_" + remoteDir1.lastModified() + + "-" + mockClientContext.getRemoteDirectoryManager() + .getRemoteFileSize(remoteUrl); + String suffix2 = "_" + remoteDir2.lastModified() + + "-" + mockClientContext.getRemoteDirectoryManager() + .getRemoteFileSize(remoteUrl2); + 
+ String[] params = createCommonParamsBuilder() + .withLocalization(remoteUrl, containerPath) + .withLocalization(remoteUrl2, containerPath2) + .withLocalization(remoteUrl, containerPath3) + .withLocalization(remoteUrl2, containerPath4) + .build(); + RunJobCli runJobCli = createRunJobCli(); + runJobCli.run(params); + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( + runJobCli.getJobSubmitter()); + assertNumberOfServiceComponents(serviceSpec, 3); + + // Ensure download remote dir 4 times + verifyRdmCopyToRemoteLocalCalls(4); + + // Ensure downloaded temp files are deleted + assertFilesAreDeleted( + testCommons.getFileUtils().getTempFileWithName(remoteUrl), + testCommons.getFileUtils().getTempFileWithName(remoteUrl2)); + + // Ensure zip file are deleted + assertFilesAreDeleted( + testCommons.getFileUtils() + .getTempFileWithName(remoteUrl + suffix1 + ZIP_EXTENSION), + testCommons.getFileUtils() + .getTempFileWithName(remoteUrl2 + suffix2 + ZIP_EXTENSION)); + + // Ensure files will be localized + List files = serviceSpec.getConfiguration().getFiles(); + assertNumberOfLocalizations(files, 4); + + ConfigFile expectedConfigFile = new ConfigFile(); + // The hdfs dir should be download and compress and let YARN to uncompress + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, remoteUrl, suffix1 + ZIP_EXTENSION)); + // Relative path in container, but not "." or "./". 
Use its own name + expectedConfigFile.setDestFile(new Path(containerPath).getName()); + assertConfigFile(expectedConfigFile, files.get(0)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, remoteUrl2, suffix2 + ZIP_EXTENSION)); + expectedConfigFile.setDestFile(new Path(containerPath2).getName()); + assertConfigFile(expectedConfigFile, files.get(1)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, remoteUrl, suffix1 + ZIP_EXTENSION)); + // Relative path in container ".", use remote path name + expectedConfigFile.setDestFile(new Path(remoteUrl).getName()); + assertConfigFile(expectedConfigFile, files.get(2)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, remoteUrl2, suffix2 + ZIP_EXTENSION)); + // Relative path in container ".", use remote path name + expectedConfigFile.setDestFile(new Path(remoteUrl2).getName()); + assertConfigFile(expectedConfigFile, files.get(3)); + + // Ensure mounts env value is correct. Add one mount string + String env = serviceSpec.getConfiguration().getEnv() + .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); + + String expectedMounts = + new Path(containerPath2).getName() + ":" + containerPath2 + ":rw"; + assertTrue(env.contains(expectedMounts)); + } + + /** + * Test if file/dir to be localized whose size exceeds limit. + * Max 10MB in configuration, mock remote will + * always return file size 100MB. + * This configuration will fail the job which has remoteUri + * But don't impact local dir/file + * + * --localization https://a/b/1.patch:. 
+ * --localization s3a://a/dir:/opt/mys3dir + * --localization /temp/script2.py:./ + */ + @Test + public void testRunJobRemoteUriExceedLocalizationSize() throws Exception { + String remoteUri1 = "https://a/b/1.patch"; + String containerLocal1 = "."; + String remoteUri2 = "s3a://a/s3dir"; + String containerLocal2 = "/opt/mys3dir"; + String localUri1 = "/temp/script2"; + String containerLocal3 = "./"; + + SubmarineConfiguration submarineConf = new SubmarineConfiguration(); + + // Max 10MB, mock remote will always return file size 100MB. + submarineConf.set( + SubmarineConfiguration.LOCALIZATION_MAX_ALLOWED_FILE_SIZE_MB, + "10"); + mockClientContext.setSubmarineConfig(submarineConf); + + assertFalse(SubmarineLogs.isVerbose()); + + // create remote file in local staging dir to simulate + Path stagingDir = getStagingDir(); + testCommons.getFileUtils().createFileInDir(stagingDir, remoteUri1); + File remoteDir1 = + testCommons.getFileUtils().createDirectory(stagingDir, remoteUri2); + testCommons.getFileUtils().createFileInDir(remoteDir1, "afile"); + + // create local file, we need to put it under local temp dir + File localFile1 = testCommons.getFileUtils().createFileInTempDir(localUri1); + + try { + RunJobCli runJobCli = createRunJobCli(); + String[] params = createCommonParamsBuilder() + .withLocalization(remoteUri1, containerLocal1) + .build(); + runJobCli.run(params); + } catch (IOException e) { + // Shouldn't have exception because it's within file size limit + fail(); + } + // we should download because fail fast + verifyRdmCopyToRemoteLocalCalls(1); + try { + String[] params = createCommonParamsBuilder() + .withLocalization(remoteUri1, containerLocal1) + .withLocalization(remoteUri2, containerLocal2) + .withLocalization(localFile1.getAbsolutePath(), containerLocal3) + .build(); + + reset(spyRdm); + RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion(); + runJobCli.run(params); + } catch (IOException e) { + assertTrue(e.getMessage() + .contains("104857600 
exceeds configured max size:10485760")); + // we shouldn't do any download because fail fast + verifyRdmCopyToRemoteLocalCalls(0); + } + + try { + String[] params = createCommonParamsBuilder() + .withLocalization(localFile1.getAbsolutePath(), containerLocal3) + .build(); + RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion(); + runJobCli.run(params); + } catch (IOException e) { + assertTrue(e.getMessage() + .contains("104857600 exceeds configured max size:10485760")); + // we shouldn't do any download because fail fast + verifyRdmCopyToRemoteLocalCalls(0); + } + } + + /** + * Test remote Uri doesn't exist. + * */ + @Test + public void testRunJobWithNonExistRemoteUri() throws Exception { + String remoteUri1 = "hdfs:///a/b/1.patch"; + String containerLocal1 = "."; + String localUri1 = "/a/b/c"; + String containerLocal2 = "./"; + + try { + String[] params = createCommonParamsBuilder() + .withLocalization(remoteUri1, containerLocal1) + .build(); + RunJobCli runJobCli = createRunJobCli(); + runJobCli.run(params); + } catch (IOException e) { + assertTrue(e.getMessage().contains("doesn't exists")); + } + + try { + String[] params = createCommonParamsBuilder() + .withLocalization(localUri1, containerLocal2) + .build(); + RunJobCli runJobCli = createRunJobCliWithoutVerboseAssertion(); + runJobCli.run(params); + } catch (IOException e) { + assertTrue(e.getMessage().contains("doesn't exists")); + } + } + + /** + * Test local dir + * --localization /user/yarn/mydir:./mydir1 + * --localization /user/yarn/mydir2:/opt/dir2:rw + * --localization /user/yarn/mydir2:. 
+ */ + @Test + public void testRunJobWithLocalDirLocalization() throws Exception { + String localUrl = "/user/yarn/mydir"; + String containerPath = "./mydir1"; + String localUrl2 = "/user/yarn/mydir2"; + String containerPath2 = "/opt/dir2"; + String containerPath3 = "."; + + // create local file + File localDir1 = testCommons.getFileUtils().createDirInTempDir(localUrl); + testCommons.getFileUtils().createFileInDir(localDir1, "1.py"); + testCommons.getFileUtils().createFileInDir(localDir1, "2.py"); + + File localDir2 = testCommons.getFileUtils().createDirInTempDir(localUrl2); + testCommons.getFileUtils().createFileInDir(localDir2, "3.py"); + testCommons.getFileUtils().createFileInDir(localDir2, "4.py"); + + String suffix1 = "_" + localDir1.lastModified() + + "-" + localDir1.length(); + String suffix2 = "_" + localDir2.lastModified() + + "-" + localDir2.length(); + + String[] params = createCommonParamsBuilder() + .withLocalization(localDir1.getAbsolutePath(), containerPath) + .withLocalization(localDir2.getAbsolutePath(), containerPath2) + .withLocalization(localDir2.getAbsolutePath(), containerPath3) + .build(); + RunJobCli runJobCli = createRunJobCli(); + runJobCli.run(params); + + Service serviceSpec = testCommons.getServiceSpecFromJobSubmitter( + runJobCli.getJobSubmitter()); + assertNumberOfServiceComponents(serviceSpec, 3); + + // we shouldn't do any download + verifyRdmCopyToRemoteLocalCalls(0); + + // Ensure local original files are not deleted + assertTrue(localDir1.exists()); + assertTrue(localDir2.exists()); + + // Ensure zip file are deleted + assertFalse( + testCommons.getFileUtils() + .getTempFileWithName(localUrl + suffix1 + ZIP_EXTENSION) + .exists()); + assertFalse( + testCommons.getFileUtils() + .getTempFileWithName(localUrl2 + suffix2 + ZIP_EXTENSION) + .exists()); + + // Ensure dirs will be zipped and localized + List files = serviceSpec.getConfiguration().getFiles(); + assertNumberOfLocalizations(files, 3); + + Path stagingDir = getStagingDir(); 
+ ConfigFile expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, localUrl, suffix1 + ZIP_EXTENSION)); + expectedConfigFile.setDestFile(new Path(containerPath).getName()); + assertConfigFile(expectedConfigFile, files.get(0)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, localUrl2, suffix2 + ZIP_EXTENSION)); + expectedConfigFile.setDestFile(new Path(containerPath2).getName()); + assertConfigFile(expectedConfigFile, files.get(1)); + + expectedConfigFile = new ConfigFile(); + expectedConfigFile.setType(ConfigFile.TypeEnum.ARCHIVE); + expectedConfigFile.setSrcFile( + getFilePathWithSuffix(stagingDir, localUrl2, suffix2 + ZIP_EXTENSION)); + expectedConfigFile.setDestFile(new Path(localUrl2).getName()); + assertConfigFile(expectedConfigFile, files.get(2)); + + // Ensure mounts env value is correct + String env = serviceSpec.getConfiguration().getEnv() + .get("YARN_CONTAINER_RUNTIME_DOCKER_MOUNTS"); + String expectedMounts = new Path(containerPath2).getName() + + ":" + containerPath2 + ":rw"; + + assertTrue(env.contains(expectedMounts)); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestServiceWrapper.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestServiceWrapper.java new file mode 100644 index 00000000000..cd5c05c82b1 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestServiceWrapper.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; + +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.service.api.records.Service; +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Class to test the {@link ServiceWrapper}. 
+ */ +public class TestServiceWrapper { + private AbstractComponent createMockAbstractComponent(Component mockComponent, + String componentName, String localScriptFile) throws IOException { + when(mockComponent.getName()).thenReturn(componentName); + + AbstractComponent mockAbstractComponent = mock(AbstractComponent.class); + when(mockAbstractComponent.createComponent()).thenReturn(mockComponent); + when(mockAbstractComponent.getLocalScriptFile()) + .thenReturn(localScriptFile); + return mockAbstractComponent; + } + + @Test + public void testWithSingleComponent() throws IOException { + Service mockService = mock(Service.class); + ServiceWrapper serviceWrapper = new ServiceWrapper(mockService); + + Component mockComponent = mock(Component.class); + AbstractComponent mockAbstractComponent = + createMockAbstractComponent(mockComponent, "testComponent", + "testLocalScriptFile"); + serviceWrapper.addComponent(mockAbstractComponent); + + verify(mockService).addComponent(eq(mockComponent)); + + String launchCommand = + serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent"); + assertEquals("testLocalScriptFile", launchCommand); + } + + @Test + public void testWithMultipleComponent() throws IOException { + Service mockService = mock(Service.class); + ServiceWrapper serviceWrapper = new ServiceWrapper(mockService); + + Component mockComponent1 = mock(Component.class); + AbstractComponent mockAbstractComponent1 = + createMockAbstractComponent(mockComponent1, "testComponent1", + "testLocalScriptFile1"); + + Component mockComponent2 = mock(Component.class); + AbstractComponent mockAbstractComponent2 = + createMockAbstractComponent(mockComponent2, "testComponent2", + "testLocalScriptFile2"); + + serviceWrapper.addComponent(mockAbstractComponent1); + serviceWrapper.addComponent(mockAbstractComponent2); + + verify(mockService).addComponent(eq(mockComponent1)); + verify(mockService).addComponent(eq(mockComponent2)); + + String launchCommand1 = + 
serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent1"); + assertEquals("testLocalScriptFile1", launchCommand1); + + String launchCommand2 = + serviceWrapper.getLocalLaunchCommandPathForComponent("testComponent2"); + assertEquals("testLocalScriptFile2", launchCommand2); + } + + +} \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java index d7dc8749440..c8b2388814a 100644 --- a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/TestTFConfigGenerator.java @@ -14,26 +14,30 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons; import org.codehaus.jettison.json.JSONException; import org.junit.Assert; import org.junit.Test; +/** + * Class to test some functionality of {@link TensorFlowCommons}. 
+ */ public class TestTFConfigGenerator { @Test public void testSimpleDistributedTFConfigGenerator() throws JSONException { - String json = YarnServiceUtils.getTFConfigEnv("worker", 5, 3, "wtan", + String json = TensorFlowCommons.getTFConfigEnv("worker", 5, 3, "wtan", "tf-job-001", "example.com"); String expected = "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\",\\\"worker-1.wtan.tf-job-001.example.com:8000\\\",\\\"worker-2.wtan.tf-job-001.example.com:8000\\\",\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\",\\\"ps-1.wtan.tf-job-001.example.com:8000\\\",\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"worker\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}"; Assert.assertEquals(expected, json); - json = YarnServiceUtils.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001", + json = TensorFlowCommons.getTFConfigEnv("ps", 5, 3, "wtan", "tf-job-001", "example.com"); expected = "{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\",\\\"worker-1.wtan.tf-job-001.example.com:8000\\\",\\\"worker-2.wtan.tf-job-001.example.com:8000\\\",\\\"worker-3.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\",\\\"ps-1.wtan.tf-job-001.example.com:8000\\\",\\\"ps-2.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"ps\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}"; Assert.assertEquals(expected, json); - json = YarnServiceUtils.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001", + json = TensorFlowCommons.getTFConfigEnv("master", 2, 1, "wtan", "tf-job-001", "example.com"); expected = 
"{\\\"cluster\\\":{\\\"master\\\":[\\\"master-0.wtan.tf-job-001.example.com:8000\\\"],\\\"worker\\\":[\\\"worker-0.wtan.tf-job-001.example.com:8000\\\"],\\\"ps\\\":[\\\"ps-0.wtan.tf-job-001.example.com:8000\\\"]},\\\"task\\\":{ \\\"type\\\":\\\"master\\\", \\\"index\\\":$_TASK_INDEX},\\\"environment\\\":\\\"cloud\\\"}"; diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommandTestHelper.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommandTestHelper.java new file mode 100644 index 00000000000..5275603f8b1 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/AbstractLaunchCommandTestHelper.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.MockClientContext; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand; +import org.junit.Rule; +import org.junit.rules.ExpectedException; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; + +import static junit.framework.TestCase.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * This class is an abstract base class for testing Tensorboard and TensorFlow + * launch commands. 
+ */ +public abstract class AbstractLaunchCommandTestHelper { + private TaskType taskType; + private boolean useTaskTypeOverride; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + private void assertScriptContainsExportedEnvVar(List fileContents, + String varName) { + String expected = String.format("export %s=", varName); + assertScriptContainsLine(fileContents, expected); + } + + public static void assertScriptContainsExportedEnvVarWithValue( + List fileContents, String varName, String value) { + String expected = String.format("export %s=%s", varName, value); + assertScriptContainsLine(fileContents, expected); + } + + public static void assertScriptContainsLine(List fileContents, + String expected) { + String message = String.format( + "File does not contain expected line '%s'!" + " File contents: %s", + expected, Arrays.toString(fileContents.toArray())); + assertTrue(message, fileContents.contains(expected)); + } + + public static void assertScriptContainsLineWithRegex( + List fileContents, + String regex) { + String message = String.format( + "File does not contain expected line '%s'!" + " File contents: %s", + regex, Arrays.toString(fileContents.toArray())); + + for (String line : fileContents) { + if (line.matches(regex)) { + return; + } + } + fail(message); + } + + public static void assertScriptDoesNotContainLine(List fileContents, + String expected) { + String message = String.format( + "File contains unexpected line '%s'!" 
+ " File contents: %s", + expected, Arrays.toString(fileContents.toArray())); + assertFalse(message, fileContents.contains(expected)); + } + + + private AbstractLaunchCommand createLaunchCommandByTaskType(TaskType taskType, + RunJobParameters params) throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + HadoopEnvironmentSetup hadoopEnvSetup = + new HadoopEnvironmentSetup(mockClientContext, fsOperations); + Component component = new Component(); + Configuration yarnConfig = new Configuration(); + + return createLaunchCommandByTaskTypeInternal(taskType, params, + hadoopEnvSetup, component, yarnConfig); + } + + private AbstractLaunchCommand createLaunchCommandByTaskTypeInternal( + TaskType taskType, RunJobParameters params, + HadoopEnvironmentSetup hadoopEnvSetup, Component component, + Configuration yarnConfig) + throws IOException { + if (taskType == TaskType.TENSORBOARD) { + return new TensorBoardLaunchCommand( + hadoopEnvSetup, getTaskType(taskType), component, params); + } else if (taskType == TaskType.WORKER + || taskType == TaskType.PRIMARY_WORKER) { + return new TensorFlowWorkerLaunchCommand( + hadoopEnvSetup, getTaskType(taskType), component, params, yarnConfig); + } else if (taskType == TaskType.PS) { + return new TensorFlowPsLaunchCommand( + hadoopEnvSetup, getTaskType(taskType), component, params, yarnConfig); + } + throw new IllegalStateException("Unknown taskType!"); + } + + public void overrideTaskType(TaskType taskType) { + this.taskType = taskType; + this.useTaskTypeOverride = true; + } + + private TaskType getTaskType(TaskType taskType) { + if (useTaskTypeOverride) { + return this.taskType; + } + return taskType; + } + + public void testHdfsRelatedEnvironmentIsUndefined(TaskType taskType, + RunJobParameters params) throws IOException { + AbstractLaunchCommand launchCommand = + createLaunchCommandByTaskType(taskType, params); + + 
expectedException.expect(IOException.class); + expectedException + .expectMessage("Failed to detect HDFS-related environments."); + launchCommand.generateLaunchScript(); + } + + public List testHdfsRelatedEnvironmentIsDefined(TaskType taskType, + RunJobParameters params) throws IOException { + AbstractLaunchCommand launchCommand = + createLaunchCommandByTaskType(taskType, params); + + String result = launchCommand.generateLaunchScript(); + assertNotNull(result); + File resultFile = new File(result); + assertTrue(resultFile.exists()); + + List fileContents = Files.readAllLines( + Paths.get(resultFile.toURI()), + Charset.forName("UTF-8")); + + assertEquals("#!/bin/bash", fileContents.get(0)); + assertScriptContainsExportedEnvVar(fileContents, "HADOOP_HOME"); + assertScriptContainsExportedEnvVar(fileContents, "HADOOP_YARN_HOME"); + assertScriptContainsExportedEnvVarWithValue(fileContents, + "HADOOP_HDFS_HOME", "testHdfsHome"); + assertScriptContainsExportedEnvVarWithValue(fileContents, + "HADOOP_COMMON_HOME", "testHdfsHome"); + assertScriptContainsExportedEnvVarWithValue(fileContents, "HADOOP_CONF_DIR", + "$WORK_DIR"); + assertScriptContainsExportedEnvVarWithValue(fileContents, "JAVA_HOME", + "testJavaHome"); + assertScriptContainsExportedEnvVarWithValue(fileContents, "LD_LIBRARY_PATH", + "$LD_LIBRARY_PATH:$JAVA_HOME/lib/amd64/server"); + assertScriptContainsExportedEnvVarWithValue(fileContents, "CLASSPATH", + "`$HADOOP_HDFS_HOME/bin/hadoop classpath --glob`"); + + return fileContents; + } + +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TestLaunchCommandFactory.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TestLaunchCommandFactory.java new file mode 100644 index 00000000000..6351f6160fd --- /dev/null +++ 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/command/TestLaunchCommandFactory.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand; +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +/** + * This class is to test the {@link LaunchCommandFactory}. 
+ */ +public class TestLaunchCommandFactory { + + private LaunchCommandFactory createLaunchCommandFactory( + RunJobParameters parameters) { + HadoopEnvironmentSetup hadoopEnvSetup = mock(HadoopEnvironmentSetup.class); + Configuration configuration = mock(Configuration.class); + return new LaunchCommandFactory(hadoopEnvSetup, parameters, configuration); + } + + @Test + public void createLaunchCommandWorkerAndPrimaryWorker() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setWorkerLaunchCmd("testWorkerLaunchCommand"); + LaunchCommandFactory launchCommandFactory = createLaunchCommandFactory( + parameters); + Component mockComponent = mock(Component.class); + + AbstractLaunchCommand launchCommand = + launchCommandFactory.createLaunchCommand(TaskType.PRIMARY_WORKER, + mockComponent); + + assertTrue(launchCommand instanceof TensorFlowWorkerLaunchCommand); + + launchCommand = + launchCommandFactory.createLaunchCommand(TaskType.WORKER, + mockComponent); + assertTrue(launchCommand instanceof TensorFlowWorkerLaunchCommand); + + } + + @Test + public void createLaunchCommandPs() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setPSLaunchCmd("testPSLaunchCommand"); + LaunchCommandFactory launchCommandFactory = createLaunchCommandFactory( + parameters); + Component mockComponent = mock(Component.class); + + AbstractLaunchCommand launchCommand = + launchCommandFactory.createLaunchCommand(TaskType.PS, + mockComponent); + + assertTrue(launchCommand instanceof TensorFlowPsLaunchCommand); + } + + @Test + public void createLaunchCommandTensorboard() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setCheckpointPath("testCheckpointPath"); + LaunchCommandFactory launchCommandFactory = + createLaunchCommandFactory(parameters); + Component mockComponent = mock(Component.class); + + AbstractLaunchCommand launchCommand = + 
launchCommandFactory.createLaunchCommand(TaskType.TENSORBOARD, + mockComponent); + + assertTrue(launchCommand instanceof TensorBoardLaunchCommand); + } + +} \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorBoardLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorBoardLaunchCommand.java new file mode 100644 index 00000000000..b854cdfc236 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorBoardLaunchCommand.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; + +import com.google.common.collect.ImmutableList; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.MockClientContext; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommandTestHelper; +import org.junit.Test; + +import java.io.IOException; +import java.util.List; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_HADOOP_HDFS_HOME; +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_JAVA_HOME; + +/** + * This class is to test the {@link TensorBoardLaunchCommand}. 
+ */ +public class TestTensorBoardLaunchCommand extends + AbstractLaunchCommandTestHelper { + + @Test + public void testHdfsRelatedEnvironmentIsUndefined() throws IOException { + RunJobParameters params = new RunJobParameters(); + params.setInputPath("hdfs://bla"); + params.setName("testJobname"); + params.setCheckpointPath("something"); + + testHdfsRelatedEnvironmentIsUndefined(TaskType.TENSORBOARD, + params); + } + + @Test + public void testHdfsRelatedEnvironmentIsDefined() throws IOException { + RunJobParameters params = new RunJobParameters(); + params.setName("testName"); + params.setCheckpointPath("testCheckpointPath"); + params.setInputPath("hdfs://bla"); + params.setEnvars(ImmutableList.of( + DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome", + DOCKER_JAVA_HOME + "=" + "testJavaHome")); + + List fileContents = + testHdfsRelatedEnvironmentIsDefined(TaskType.TENSORBOARD, + params); + assertScriptContainsExportedEnvVarWithValue(fileContents, "LC_ALL", + "C && tensorboard --logdir=testCheckpointPath"); + } + + @Test + public void testCheckpointPathUndefined() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + HadoopEnvironmentSetup hadoopEnvSetup = + new HadoopEnvironmentSetup(mockClientContext, fsOperations); + + Component component = new Component(); + RunJobParameters params = new RunJobParameters(); + params.setCheckpointPath(null); + + expectedException.expect(NullPointerException.class); + expectedException.expectMessage("CheckpointPath must not be null"); + new TensorBoardLaunchCommand(hadoopEnvSetup, TaskType.TENSORBOARD, + component, params); + } + + @Test + public void testCheckpointPathEmptyString() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + HadoopEnvironmentSetup hadoopEnvSetup = + new 
HadoopEnvironmentSetup(mockClientContext, fsOperations); + + Component component = new Component(); + RunJobParameters params = new RunJobParameters(); + params.setCheckpointPath(""); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("CheckpointPath must not be empty"); + new TensorBoardLaunchCommand(hadoopEnvSetup, TaskType.TENSORBOARD, + component, params); + } +} \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorFlowLaunchCommand.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorFlowLaunchCommand.java new file mode 100644 index 00000000000..fa584c75b9b --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/command/TestTensorFlowLaunchCommand.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command; + +import com.google.common.collect.ImmutableList; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.MockClientContext; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommandTestHelper; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_HADOOP_HDFS_HOME; +import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup.DOCKER_JAVA_HOME; + +/** + * This class is to test the implementors of {@link TensorFlowLaunchCommand}. 
+ */ +@RunWith(Parameterized.class) +public class TestTensorFlowLaunchCommand + extends AbstractLaunchCommandTestHelper { + private TaskType taskType; + + @Parameterized.Parameters + public static Collection data() { + Collection params = new ArrayList<>(); + params.add(new Object[]{TaskType.WORKER }); + params.add(new Object[]{TaskType.PS }); + return params; + } + + public TestTensorFlowLaunchCommand(TaskType taskType) { + this.taskType = taskType; + } + + + private void assertScriptContainsLaunchCommand(List fileContents, + RunJobParameters params) { + String launchCommand = null; + if (taskType == TaskType.WORKER) { + launchCommand = params.getWorkerLaunchCmd(); + } else if (taskType == TaskType.PS) { + launchCommand = params.getPSLaunchCmd(); + } + assertScriptContainsLine(fileContents, launchCommand); + } + + private void setLaunchCommandToParams(RunJobParameters params) { + if (taskType == TaskType.WORKER) { + params.setWorkerLaunchCmd("testWorkerLaunchCommand"); + } else if (taskType == TaskType.PS) { + params.setPSLaunchCmd("testPsLaunchCommand"); + } + } + + private void setLaunchCommandToParams(RunJobParameters params, String value) { + if (taskType == TaskType.WORKER) { + params.setWorkerLaunchCmd(value); + } else if (taskType == TaskType.PS) { + params.setPSLaunchCmd(value); + } + } + + private void assertTypeInJson(List fileContents) { + String expectedType = null; + if (taskType == TaskType.WORKER) { + expectedType = "worker"; + } else if (taskType == TaskType.PS) { + expectedType = "ps"; + } + assertScriptContainsLineWithRegex(fileContents, String.format(".*type.*:" + + ".*%s.*", expectedType)); + } + + private TensorFlowLaunchCommand createTensorFlowLaunchCommandObject( + HadoopEnvironmentSetup hadoopEnvSetup, Configuration yarnConfig, + Component component, RunJobParameters params) throws IOException { + if (taskType == TaskType.WORKER) { + return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType, + component, + params, yarnConfig); + } 
else if (taskType == TaskType.PS) { + return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component, + params, yarnConfig); + } + throw new IllegalStateException("Unknown tasktype!"); + } + + @Test + public void testHdfsRelatedEnvironmentIsUndefined() throws IOException { + RunJobParameters params = new RunJobParameters(); + params.setInputPath("hdfs://bla"); + params.setName("testJobname"); + setLaunchCommandToParams(params); + + testHdfsRelatedEnvironmentIsUndefined(taskType, params); + } + + @Test + public void testHdfsRelatedEnvironmentIsDefined() throws IOException { + RunJobParameters params = new RunJobParameters(); + params.setName("testName"); + params.setInputPath("hdfs://bla"); + params.setEnvars(ImmutableList.of( + DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome", + DOCKER_JAVA_HOME + "=" + "testJavaHome")); + setLaunchCommandToParams(params); + + List fileContents = + testHdfsRelatedEnvironmentIsDefined(taskType, + params); + assertScriptContainsLaunchCommand(fileContents, params); + assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG="); + } + + @Test + public void testLaunchCommandIsNull() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + HadoopEnvironmentSetup hadoopEnvSetup = + new HadoopEnvironmentSetup(mockClientContext, fsOperations); + Configuration yarnConfig = new Configuration(); + + Component component = new Component(); + RunJobParameters params = new RunJobParameters(); + params.setName("testName"); + setLaunchCommandToParams(params, null); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("LaunchCommand must not be null or empty"); + TensorFlowLaunchCommand launchCommand = + createTensorFlowLaunchCommandObject(hadoopEnvSetup, yarnConfig, + component, + params); + launchCommand.generateLaunchScript(); + } + + @Test + public void 
testLaunchCommandIsEmpty() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + HadoopEnvironmentSetup hadoopEnvSetup = + new HadoopEnvironmentSetup(mockClientContext, fsOperations); + Configuration yarnConfig = new Configuration(); + + Component component = new Component(); + RunJobParameters params = new RunJobParameters(); + params.setName("testName"); + setLaunchCommandToParams(params, ""); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("LaunchCommand must not be null or empty"); + TensorFlowLaunchCommand launchCommand = + createTensorFlowLaunchCommandObject(hadoopEnvSetup, yarnConfig, + component, params); + launchCommand.generateLaunchScript(); + } + + @Test + public void testDistributedTrainingMissingTaskType() throws IOException { + overrideTaskType(null); + + RunJobParameters params = new RunJobParameters(); + params.setDistributed(true); + params.setName("testName"); + params.setInputPath("hdfs://bla"); + params.setEnvars(ImmutableList.of( + DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome", + DOCKER_JAVA_HOME + "=" + "testJavaHome")); + setLaunchCommandToParams(params); + + expectedException.expect(NullPointerException.class); + expectedException.expectMessage("TaskType must not be null"); + testHdfsRelatedEnvironmentIsDefined(taskType, params); + } + + @Test + public void testDistributedTrainingNumberOfWorkersAndPsIsZero() + throws IOException { + RunJobParameters params = new RunJobParameters(); + params.setDistributed(true); + params.setNumWorkers(0); + params.setNumPS(0); + params.setName("testName"); + params.setInputPath("hdfs://bla"); + params.setEnvars(ImmutableList.of( + DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome", + DOCKER_JAVA_HOME + "=" + "testJavaHome")); + setLaunchCommandToParams(params); + + List fileContents = + testHdfsRelatedEnvironmentIsDefined(taskType, params); + + 
assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG="); + assertScriptContainsLineWithRegex(fileContents, ".*worker.*:\\[\\].*"); + assertScriptContainsLineWithRegex(fileContents, ".*ps.*:\\[\\].*"); + assertTypeInJson(fileContents); + } + + @Test + public void testDistributedTrainingNumberOfWorkersAndPsIsNonZero() + throws IOException { + RunJobParameters params = new RunJobParameters(); + params.setDistributed(true); + params.setNumWorkers(3); + params.setNumPS(2); + params.setName("testName"); + params.setInputPath("hdfs://bla"); + params.setEnvars(ImmutableList.of( + DOCKER_HADOOP_HDFS_HOME + "=" + "testHdfsHome", + DOCKER_JAVA_HOME + "=" + "testJavaHome")); + setLaunchCommandToParams(params); + + List fileContents = + testHdfsRelatedEnvironmentIsDefined(taskType, params); + + //assert we have multiple PS and workers + assertScriptDoesNotContainLine(fileContents, "export TF_CONFIG="); + assertScriptContainsLineWithRegex(fileContents, ".*worker.*:\\[.*,.*\\].*"); + assertScriptContainsLineWithRegex(fileContents, ".*ps.*:\\[.*,.*\\].*"); + assertTypeInJson(fileContents); + } + + +} \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/ComponentTestCommons.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/ComponentTestCommons.java new file mode 100644 index 00000000000..420fe5a4509 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/ComponentTestCommons.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.service.api.ServiceApiConstants; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.submarine.common.Envs; +import org.apache.hadoop.yarn.submarine.common.MockClientContext; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * This class has some helper methods and fields + * in order to test TensorFlow-related Components easier. 
+ */ +public class ComponentTestCommons { + String userName; + TaskType taskType; + LaunchCommandFactory mockLaunchCommandFactory; + FileSystemOperations fsOperations; + MockClientContext mockClientContext; + Configuration yarnConfig; + Resource resource; + + ComponentTestCommons(TaskType taskType) { + this.taskType = taskType; + } + + public void setup() throws IOException { + this.userName = System.getProperty("user.name"); + this.resource = Resource.newInstance(4000, 10); + setupDependencies(); + } + + private void setupDependencies() throws IOException { + fsOperations = mock(FileSystemOperations.class); + mockClientContext = new MockClientContext(); + mockLaunchCommandFactory = mock(LaunchCommandFactory.class); + + AbstractLaunchCommand mockLaunchCommand = mock(AbstractLaunchCommand.class); + when(mockLaunchCommand.generateLaunchScript()).thenReturn("mockScript"); + when(mockLaunchCommandFactory.createLaunchCommand(eq(taskType), + any(Component.class))).thenReturn(mockLaunchCommand); + + yarnConfig = new Configuration(); + } + + void verifyCommonConfigEnvs(Component component) { + assertNotNull(component.getConfiguration().getEnv()); + assertEquals(2, component.getConfiguration().getEnv().size()); + assertEquals(ServiceApiConstants.COMPONENT_ID, + component.getConfiguration().getEnv().get(Envs.TASK_INDEX_ENV)); + assertEquals(taskType.name(), + component.getConfiguration().getEnv().get(Envs.TASK_TYPE_ENV)); + } + + void verifyResources(Component component) { + assertNotNull(component.getResource()); + assertEquals(10, (int) component.getResource().getCpus()); + assertEquals(4000, + (int) Integer.valueOf(component.getResource().getMemory())); + } +} diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorBoardComponent.java 
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorBoardComponent.java new file mode 100644 index 00000000000..1c81eb7cb34 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorBoardComponent.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.service.api.records.Artifact; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; + +/** + * This class is to test {@link TensorBoardComponent}. + */ +public class TestTensorBoardComponent { + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + private ComponentTestCommons testCommons = + new ComponentTestCommons(TaskType.TENSORBOARD); + + @Before + public void setUp() throws IOException { + testCommons.setup(); + } + + private TensorBoardComponent createTensorBoardComponent( + RunJobParameters parameters) { + return new TensorBoardComponent( + testCommons.fsOperations, + testCommons.mockClientContext.getRemoteDirectoryManager(), + parameters, + testCommons.mockLaunchCommandFactory, + testCommons.yarnConfig); + } + + @Test + public void testTensorBoardComponentWithNullResource() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setTensorboardResource(null); + + TensorBoardComponent tensorBoardComponent = + createTensorBoardComponent(parameters); + + expectedException.expect(NullPointerException.class); + expectedException.expectMessage("TensorBoard resource must not be null"); + tensorBoardComponent.createComponent(); + } + + @Test + public void 
testTensorBoardComponentWithNullJobName() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setTensorboardResource(testCommons.resource); + parameters.setName(null); + + TensorBoardComponent tensorBoardComponent = + createTensorBoardComponent(parameters); + + expectedException.expect(NullPointerException.class); + expectedException.expectMessage("Job name must not be null"); + tensorBoardComponent.createComponent(); + } + + @Test + public void testTensorBoardComponent() throws IOException { + testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain"); + + RunJobParameters parameters = new RunJobParameters(); + parameters.setTensorboardResource(testCommons.resource); + parameters.setName("testJobName"); + parameters.setTensorboardDockerImage("testTBDockerImage"); + + TensorBoardComponent tensorBoardComponent = + createTensorBoardComponent(parameters); + + Component component = tensorBoardComponent.createComponent(); + + assertEquals(testCommons.taskType.getComponentName(), component.getName()); + testCommons.verifyCommonConfigEnvs(component); + + assertEquals(1L, (long) component.getNumberOfContainers()); + assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy()); + testCommons.verifyResources(component); + assertEquals( + new Artifact().type(Artifact.TypeEnum.DOCKER).id("testTBDockerImage"), + component.getArtifact()); + + assertEquals(String.format( + "http://tensorboard-0.testJobName.%s" + ".testDomain:6006", + testCommons.userName), + tensorBoardComponent.getTensorboardLink()); + + assertEquals("./run-TENSORBOARD.sh", component.getLaunchCommand()); + verify(testCommons.fsOperations) + .uploadToRemoteFileAndLocalizeToContainerWorkDir( + any(Path.class), eq("mockScript"), eq("run-TENSORBOARD.sh"), + eq(component)); + } + +} \ No newline at end of file diff --git 
a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowPsComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowPsComponent.java new file mode 100644 index 00000000000..8027365a226 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowPsComponent.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import java.io.IOException; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.service.api.records.Artifact; +import org.apache.hadoop.yarn.service.api.records.Component; +import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.api.TaskType; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** + * This class is to test {@link TensorFlowPsComponent}. + */ +public class TestTensorFlowPsComponent { + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + private ComponentTestCommons testCommons = + new ComponentTestCommons(TaskType.PS); + + @Before + public void setUp() throws IOException { + testCommons.setup(); + } + + private TensorFlowPsComponent createPsComponent(RunJobParameters parameters) { + return new TensorFlowPsComponent( + testCommons.fsOperations, + testCommons.mockClientContext.getRemoteDirectoryManager(), + testCommons.mockLaunchCommandFactory, + parameters, + testCommons.yarnConfig); + } + + private void verifyCommons(Component component) throws IOException { + assertEquals(testCommons.taskType.getComponentName(), component.getName()); + testCommons.verifyCommonConfigEnvs(component); + + assertTrue(component.getConfiguration().getProperties().isEmpty()); + + assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy()); + testCommons.verifyResources(component); + assertEquals( + new Artifact().type(Artifact.TypeEnum.DOCKER).id("testPSDockerImage"), + 
component.getArtifact()); + + String taskTypeUppercase = testCommons.taskType.name().toUpperCase(); + String expectedScriptName = String.format("run-%s.sh", taskTypeUppercase); + assertEquals(String.format("./%s", expectedScriptName), + component.getLaunchCommand()); + verify(testCommons.fsOperations) + .uploadToRemoteFileAndLocalizeToContainerWorkDir( + any(Path.class), eq("mockScript"), eq(expectedScriptName), + eq(component)); + } + + @Test + public void testPSComponentWithNullResource() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setPsResource(null); + + TensorFlowPsComponent psComponent = + createPsComponent(parameters); + + expectedException.expect(NullPointerException.class); + expectedException.expectMessage("PS resource must not be null"); + psComponent.createComponent(); + } + + @Test + public void testPSComponentWithNullJobName() throws IOException { + RunJobParameters parameters = new RunJobParameters(); + parameters.setPsResource(testCommons.resource); + parameters.setNumPS(1); + parameters.setName(null); + + TensorFlowPsComponent psComponent = + createPsComponent(parameters); + + expectedException.expect(NullPointerException.class); + expectedException.expectMessage("Job name must not be null"); + psComponent.createComponent(); + } + + @Test + public void testPSComponentZeroNumberOfPS() throws IOException { + testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain"); + + RunJobParameters parameters = new RunJobParameters(); + parameters.setPsResource(testCommons.resource); + parameters.setName("testJobName"); + parameters.setPsDockerImage("testPSDockerImage"); + parameters.setNumPS(0); + + TensorFlowPsComponent psComponent = + createPsComponent(parameters); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Number of PS should be at least 1!"); + psComponent.createComponent(); + } + + @Test + public void testPSComponentNumPSIsOne() throws 
IOException { + testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain"); + + RunJobParameters parameters = new RunJobParameters(); + parameters.setPsResource(testCommons.resource); + parameters.setName("testJobName"); + parameters.setNumPS(1); + parameters.setPsDockerImage("testPSDockerImage"); + + TensorFlowPsComponent psComponent = + createPsComponent(parameters); + + Component component = psComponent.createComponent(); + + assertEquals(1L, (long) component.getNumberOfContainers()); + verifyCommons(component); + } + + @Test + public void testPSComponentNumPSIsTwo() throws IOException { + testCommons.yarnConfig.set("hadoop.registry.dns.domain-name", "testDomain"); + + RunJobParameters parameters = new RunJobParameters(); + parameters.setPsResource(testCommons.resource); + parameters.setName("testJobName"); + parameters.setNumPS(2); + parameters.setPsDockerImage("testPSDockerImage"); + + TensorFlowPsComponent psComponent = + createPsComponent(parameters); + + Component component = psComponent.createComponent(); + + assertEquals(2L, (long) component.getNumberOfContainers()); + verifyCommons(component); + } + +} \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowWorkerComponent.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowWorkerComponent.java new file mode 100644 index 00000000000..24bebc2d6e4 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/runtimes/yarnservice/tensorflow/component/TestTensorFlowWorkerComponent.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component;

import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Artifact;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.io.IOException;
import java.util.Map;

import static junit.framework.TestCase.assertTrue;
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;

/**
 * This class is to test {@link TensorFlowWorkerComponent}: verifies argument
 * validation, container counts, restart policy, Docker artifact, launch
 * command and configuration properties of the generated worker
 * {@link Component}.
 */
public class TestTensorFlowWorkerComponent {

  @Rule
  public ExpectedException expectedException = ExpectedException.none();

  // NOTE(review): this worker test is initialized with TaskType.TENSORBOARD;
  // looks like a copy-paste from the TensorBoard test — confirm whether
  // TaskType.WORKER was intended here.
  private ComponentTestCommons testCommons =
      new ComponentTestCommons(TaskType.TENSORBOARD);

  @Before
  public void setUp() throws IOException {
    testCommons.setup();
  }

  /** Wires the component under test to the shared test fixtures. */
  private TensorFlowWorkerComponent createWorkerComponent(
      RunJobParameters parameters) {
    return new TensorFlowWorkerComponent(
        testCommons.fsOperations,
        testCommons.mockClientContext.getRemoteDirectoryManager(),
        parameters, testCommons.taskType,
        testCommons.mockLaunchCommandFactory,
        testCommons.yarnConfig);
  }

  private void verifyCommons(Component component) throws IOException {
    verifyCommonsInternal(component, ImmutableMap.of());
  }

  private void verifyCommons(Component component,
      Map<String, String> expectedProperties) throws IOException {
    verifyCommonsInternal(component, expectedProperties);
  }

  /**
   * Asserts everything every worker component must satisfy: name, env vars,
   * restart policy, resources, Docker artifact, launch script name and its
   * localization, plus the expected configuration properties (if any).
   */
  private void verifyCommonsInternal(Component component,
      Map<String, String> expectedProperties) throws IOException {
    assertEquals(testCommons.taskType.getComponentName(), component.getName());
    testCommons.verifyCommonConfigEnvs(component);

    Map<String, String> actualProperties =
        component.getConfiguration().getProperties();
    if (!expectedProperties.isEmpty()) {
      assertFalse(actualProperties.isEmpty());
      expectedProperties.forEach(
          (k, v) -> assertEquals(v, actualProperties.get(k)));
    } else {
      assertTrue(actualProperties.isEmpty());
    }

    assertEquals(RestartPolicyEnum.NEVER, component.getRestartPolicy());
    testCommons.verifyResources(component);
    assertEquals(
        new Artifact().type(Artifact.TypeEnum.DOCKER)
            .id("testWorkerDockerImage"),
        component.getArtifact());

    // The launch script is named after the task type, e.g. "run-WORKER.sh".
    String taskTypeUppercase = testCommons.taskType.name().toUpperCase();
    String expectedScriptName = String.format("run-%s.sh", taskTypeUppercase);
    assertEquals(String.format("./%s", expectedScriptName),
        component.getLaunchCommand());
    verify(testCommons.fsOperations)
        .uploadToRemoteFileAndLocalizeToContainerWorkDir(
            any(Path.class), eq("mockScript"), eq(expectedScriptName),
            eq(component));
  }

  /** A null worker resource must be rejected up front. */
  @Test
  public void testWorkerComponentWithNullResource() throws IOException {
    RunJobParameters parameters = new RunJobParameters();
    parameters.setWorkerResource(null);

    TensorFlowWorkerComponent workerComponent =
        createWorkerComponent(parameters);

    expectedException.expect(NullPointerException.class);
    expectedException.expectMessage("Worker resource must not be null");
    workerComponent.createComponent();
  }

  /** A null job name must be rejected up front. */
  @Test
  public void testWorkerComponentWithNullJobName() throws IOException {
    RunJobParameters parameters = new RunJobParameters();
    parameters.setWorkerResource(testCommons.resource);
    parameters.setNumWorkers(1);
    parameters.setName(null);

    TensorFlowWorkerComponent workerComponent =
        createWorkerComponent(parameters);

    expectedException.expect(NullPointerException.class);
    expectedException.expectMessage("Job name must not be null");
    workerComponent.createComponent();
  }

  /** Zero workers is an invalid configuration. */
  @Test
  public void testNormalWorkerComponentZeroNumberOfWorkers()
      throws IOException {
    testCommons.yarnConfig.set("hadoop.registry.dns.domain-name",
        "testDomain");

    RunJobParameters parameters = new RunJobParameters();
    parameters.setWorkerResource(testCommons.resource);
    parameters.setName("testJobName");
    parameters.setWorkerDockerImage("testWorkerDockerImage");
    parameters.setNumWorkers(0);

    TensorFlowWorkerComponent workerComponent =
        createWorkerComponent(parameters);

    expectedException.expect(IllegalArgumentException.class);
    expectedException.expectMessage("Number of workers should be at least 1!");
    workerComponent.createComponent();
  }

  @Test
  public void testNormalWorkerComponentNumWorkersIsOne() throws IOException {
    testCommons.yarnConfig.set("hadoop.registry.dns.domain-name",
        "testDomain");

    RunJobParameters parameters = new RunJobParameters();
    parameters.setWorkerResource(testCommons.resource);
    parameters.setName("testJobName");
    parameters.setNumWorkers(1);
    parameters.setWorkerDockerImage("testWorkerDockerImage");

    TensorFlowWorkerComponent workerComponent =
        createWorkerComponent(parameters);

    Component component = workerComponent.createComponent();

    // Presumably one worker is carved out as the PRIMARY_WORKER component,
    // so the plain worker component gets numWorkers - 1 containers —
    // TODO confirm against TensorFlowWorkerComponent#createComponent.
    assertEquals(0L, (long) component.getNumberOfContainers());
    verifyCommons(component);
  }

  @Test
  public void testNormalWorkerComponentNumWorkersIsTwo() throws IOException {
    testCommons.yarnConfig.set("hadoop.registry.dns.domain-name",
        "testDomain");

    RunJobParameters parameters = new RunJobParameters();
    parameters.setWorkerResource(testCommons.resource);
    parameters.setName("testJobName");
    parameters.setNumWorkers(2);
    parameters.setWorkerDockerImage("testWorkerDockerImage");

    TensorFlowWorkerComponent workerComponent =
        createWorkerComponent(parameters);

    Component component = workerComponent.createComponent();

    // See the note in testNormalWorkerComponentNumWorkersIsOne: 2 workers
    // minus the primary worker leaves 1 container for this component.
    assertEquals(1L, (long) component.getNumberOfContainers());
    verifyCommons(component);
  }

  /** The primary worker additionally reports container state as service state. */
  @Test
  public void testPrimaryWorkerComponentNumWorkersIsTwo() throws IOException {
    testCommons.yarnConfig.set("hadoop.registry.dns.domain-name",
        "testDomain");
    testCommons = new ComponentTestCommons(TaskType.PRIMARY_WORKER);
    testCommons.setup();

    RunJobParameters parameters = new RunJobParameters();
    parameters.setWorkerResource(testCommons.resource);
    parameters.setName("testJobName");
    parameters.setNumWorkers(2);
    parameters.setWorkerDockerImage("testWorkerDockerImage");

    TensorFlowWorkerComponent workerComponent =
        createWorkerComponent(parameters);

    Component component = workerComponent.createComponent();

    assertEquals(1L, (long) component.getNumberOfContainers());
    verifyCommons(component, ImmutableMap.of(
        CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true"));
  }

}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.utils;

import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.File;
import java.io.IOException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;

/**
 * This class is to test {@link ClassPathUtilities}: looks up files by name on
 * the java.class.path, both directly and inside classpath directories.
 * The original classpath is captured once and restored after every test so
 * tests cannot leak classpath changes into each other.
 */
public class TestClassPathUtilities {

  private static final String CLASSPATH_KEY = "java.class.path";
  private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests();
  private static String originalClasspath;

  @BeforeClass
  public static void setUpClass() {
    originalClasspath = System.getProperty(CLASSPATH_KEY);
  }

  @Before
  public void setUp() {
    fileUtils.setup();
  }

  @After
  public void teardown() throws IOException {
    fileUtils.teardown();
    // Undo any classpath mutation performed by addFileToClasspath().
    System.setProperty(CLASSPATH_KEY, originalClasspath);
  }

  /**
   * Appends the given file (or directory) to the classpath system property.
   * Uses the platform-specific {@link File#pathSeparator} (':' on Unix,
   * ';' on Windows) instead of a hard-coded colon, so the test is portable.
   */
  private static void addFileToClasspath(File file) {
    String newClasspath =
        originalClasspath + File.pathSeparator + file.getAbsolutePath();
    System.setProperty(CLASSPATH_KEY, newClasspath);
  }

  /** A name that is nowhere on the classpath resolves to null. */
  @Test
  public void findFileNotInClasspath() {
    File resultFile = ClassPathUtilities.findFileOnClassPath("bla");
    assertNull(resultFile);
  }

  /** A file added to the classpath directly is found by name. */
  @Test
  public void findFileOnClasspath() throws Exception {
    File testFile = fileUtils.createFileInTempDir("testFile");

    addFileToClasspath(testFile);
    File resultFile = ClassPathUtilities.findFileOnClassPath("testFile");

    assertNotNull(resultFile);
    assertEquals(testFile.getAbsolutePath(), resultFile.getAbsolutePath());
  }

  /** A file inside a classpath directory entry is found by name. */
  @Test
  public void findDirectoryOnClasspath() throws Exception {
    File testDir = fileUtils.createDirInTempDir("testDir");
    File testFile = fileUtils.createFileInDir(testDir, "testFile");

    addFileToClasspath(testDir);
    File resultFile = ClassPathUtilities.findFileOnClassPath("testFile");

    assertNotNull(resultFile);
    assertEquals(testFile.getAbsolutePath(), resultFile.getAbsolutePath());
  }

}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import org.apache.hadoop.yarn.service.api.records.Configuration;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.junit.Test;

import java.util.Map;

import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * This class is to test {@link EnvironmentUtilities}: parsing of "KEY=VALUE"
 * environment strings and merging of user-supplied env vars into a service's
 * configuration, including the default (and Kerberos-specific) Docker mounts.
 */
public class TestEnvironmentUtilities {

  /** Service whose configuration starts with a mutable, empty env map. */
  private Service createServiceWithEmptyEnvVars() {
    return createServiceWithEnvVars(Maps.newHashMap());
  }

  /** Service whose configuration is backed by the given env map. */
  private Service createServiceWithEnvVars(Map<String, String> envVars) {
    Service service = mock(Service.class);
    Configuration config = mock(Configuration.class);
    when(config.getEnv()).thenReturn(envVars);
    when(service.getConfiguration()).thenReturn(config);

    return service;
  }

  private void validateDefaultEnvVars(Map<String, String> resultEnvs) {
    // /etc/passwd is always mounted read-only into the container.
    assertEquals("/etc/passwd:/etc/passwd:ro",
        resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
  }

  /** Hadoop config mock returning the given security-authentication value. */
  private org.apache.hadoop.conf.Configuration
      createYarnConfigWithSecurityValue(String value) {
    org.apache.hadoop.conf.Configuration mockConfig =
        mock(org.apache.hadoop.conf.Configuration.class);
    when(mockConfig.get(HADOOP_SECURITY_AUTHENTICATION)).thenReturn(value);
    return mockConfig;
  }

  @Test
  public void testGetValueOfNullEnvVar() {
    assertEquals("", EnvironmentUtilities.getValueOfEnvironment(null));
  }

  @Test
  public void testGetValueOfEmptyEnvVar() {
    assertEquals("", EnvironmentUtilities.getValueOfEnvironment(""));
  }

  @Test
  public void testGetValueOfEnvVarJustAnEqualsSign() {
    assertEquals("", EnvironmentUtilities.getValueOfEnvironment("="));
  }

  @Test
  public void testGetValueOfEnvVarWithoutValue() {
    assertEquals("", EnvironmentUtilities.getValueOfEnvironment("a="));
  }

  @Test
  public void testGetValueOfEnvVarValidFormat() {
    assertEquals("bbb", EnvironmentUtilities.getValueOfEnvironment("a=bbb"));
  }

  @Test
  public void testHandleServiceEnvWithNullMap() {
    Service service = createServiceWithEmptyEnvVars();
    org.apache.hadoop.conf.Configuration yarnConfig =
        mock(org.apache.hadoop.conf.Configuration.class);
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(1, resultEnvs.size());
    validateDefaultEnvVars(resultEnvs);
  }

  @Test
  public void testHandleServiceEnvWithEmptyMap() {
    Service service = createServiceWithEmptyEnvVars();
    org.apache.hadoop.conf.Configuration yarnConfig =
        mock(org.apache.hadoop.conf.Configuration.class);
    // Pass an empty collection — the original passed null here, which made
    // this test an exact duplicate of testHandleServiceEnvWithNullMap.
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig,
        ImmutableList.of());

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(1, resultEnvs.size());
    validateDefaultEnvVars(resultEnvs);
  }

  @Test
  public void testHandleServiceEnvWithYarnConfigSecurityValueNonKerberos() {
    Service service = createServiceWithEmptyEnvVars();
    org.apache.hadoop.conf.Configuration yarnConfig =
        createYarnConfigWithSecurityValue("nonkerberos");
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(1, resultEnvs.size());
    validateDefaultEnvVars(resultEnvs);
  }

  @Test
  public void testHandleServiceEnvWithYarnConfigSecurityValueKerberos() {
    Service service = createServiceWithEmptyEnvVars();
    org.apache.hadoop.conf.Configuration yarnConfig =
        createYarnConfigWithSecurityValue("kerberos");
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, null);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(1, resultEnvs.size());
    // With Kerberos enabled, krb5.conf is mounted in addition to /etc/passwd.
    assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
        resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
  }

  @Test
  public void testHandleServiceEnvWithExistingEnvsAndValidNewEnvs() {
    Map<String, String> existingEnvs = Maps.newHashMap(
        ImmutableMap.<String, String>builder().
            put("a", "1").
            put("b", "2").
            build());
    ImmutableList<String> newEnvs = ImmutableList.of("c=3", "d=4");

    Service service = createServiceWithEnvVars(existingEnvs);
    org.apache.hadoop.conf.Configuration yarnConfig =
        createYarnConfigWithSecurityValue("kerberos");
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(5, resultEnvs.size());
    assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
        resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
    assertEquals("1", resultEnvs.get("a"));
    assertEquals("2", resultEnvs.get("b"));
    assertEquals("3", resultEnvs.get("c"));
    assertEquals("4", resultEnvs.get("d"));
  }

  @Test
  public void testHandleServiceEnvWithExistingEnvsAndNewEnvsWithoutEquals() {
    Map<String, String> existingEnvs = Maps.newHashMap(
        ImmutableMap.<String, String>builder().
            put("a", "1").
            put("b", "2").
            build());
    // Env strings without '=' are treated as keys with an empty value.
    ImmutableList<String> newEnvs = ImmutableList.of("c3", "d4");

    Service service = createServiceWithEnvVars(existingEnvs);
    org.apache.hadoop.conf.Configuration yarnConfig =
        createYarnConfigWithSecurityValue("kerberos");
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(5, resultEnvs.size());
    assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
        resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
    assertEquals("1", resultEnvs.get("a"));
    assertEquals("2", resultEnvs.get("b"));
    assertEquals("", resultEnvs.get("c3"));
    assertEquals("", resultEnvs.get("d4"));
  }

  @Test
  public void testHandleServiceEnvWithExistingEnvVarKey() {
    Map<String, String> existingEnvs = Maps.newHashMap(
        ImmutableMap.<String, String>builder().
            put("a", "1").
            put("b", "2").
            build());
    ImmutableList<String> newEnvs = ImmutableList.of("a=33", "c=44");

    Service service = createServiceWithEnvVars(existingEnvs);
    org.apache.hadoop.conf.Configuration yarnConfig =
        createYarnConfigWithSecurityValue("kerberos");
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(4, resultEnvs.size());
    assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
        resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
    // A value for an existing key is appended with a ':' separator,
    // not overwritten.
    assertEquals("1:33", resultEnvs.get("a"));
    assertEquals("2", resultEnvs.get("b"));
    assertEquals("44", resultEnvs.get("c"));
  }

  @Test
  public void testHandleServiceEnvWithExistingEnvVarKeyMultipleTimes() {
    Map<String, String> existingEnvs = Maps.newHashMap(
        ImmutableMap.<String, String>builder().
            put("a", "1").
            put("b", "2").
            build());
    ImmutableList<String> newEnvs = ImmutableList.of("a=33", "a=44");

    Service service = createServiceWithEnvVars(existingEnvs);
    org.apache.hadoop.conf.Configuration yarnConfig =
        createYarnConfigWithSecurityValue("kerberos");
    EnvironmentUtilities.handleServiceEnvs(service, yarnConfig, newEnvs);

    Map<String, String> resultEnvs = service.getConfiguration().getEnv();
    assertEquals(3, resultEnvs.size());
    assertEquals("/etc/passwd:/etc/passwd:ro,/etc/krb5.conf:/etc/krb5.conf:ro",
        resultEnvs.get(ENV_DOCKER_MOUNTS_FOR_CONTAINER_RUNTIME));
    // Repeated keys keep appending: "1" + ":33" + ":44".
    assertEquals("1:33:44", resultEnvs.get("a"));
    assertEquals("2", resultEnvs.get("b"));
  }

}
b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestKerberosPrincipalFactory.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal; +import org.apache.hadoop.yarn.submarine.FileUtilitiesForTests; +import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; +import org.apache.hadoop.yarn.submarine.common.MockClientContext; +import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * This class is to test {@link KerberosPrincipalFactory}. 
+ */ +public class TestKerberosPrincipalFactory { + private FileUtilitiesForTests fileUtils = new FileUtilitiesForTests(); + + @Before + public void setUp() { + fileUtils.setup(); + } + + @After + public void teardown() throws IOException { + fileUtils.teardown(); + } + + private File createKeytabFile(String keytabFileName) throws IOException { + return fileUtils.createFileInTempDir(keytabFileName); + } + + @Test + public void testCreatePrincipalEmptyPrincipalAndKeytab() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + + RunJobParameters parameters = mock(RunJobParameters.class); + when(parameters.getPrincipal()).thenReturn(""); + when(parameters.getKeytab()).thenReturn(""); + + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + KerberosPrincipal result = + KerberosPrincipalFactory.create(fsOperations, + mockClientContext.getRemoteDirectoryManager(), parameters); + + assertNull(result); + } + @Test + public void testCreatePrincipalEmptyPrincipalString() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + + RunJobParameters parameters = mock(RunJobParameters.class); + when(parameters.getPrincipal()).thenReturn(""); + when(parameters.getKeytab()).thenReturn("keytab"); + + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + KerberosPrincipal result = + KerberosPrincipalFactory.create(fsOperations, + mockClientContext.getRemoteDirectoryManager(), parameters); + + assertNull(result); + } + + @Test + public void testCreatePrincipalEmptyKeyTabString() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + + RunJobParameters parameters = mock(RunJobParameters.class); + when(parameters.getPrincipal()).thenReturn("principal"); + when(parameters.getKeytab()).thenReturn(""); + + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + KerberosPrincipal result = + 
KerberosPrincipalFactory.create(fsOperations, + mockClientContext.getRemoteDirectoryManager(), parameters); + + assertNull(result); + } + + @Test + public void testCreatePrincipalNonEmptyPrincipalAndKeytab() + throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + + RunJobParameters parameters = mock(RunJobParameters.class); + when(parameters.getPrincipal()).thenReturn("principal"); + when(parameters.getKeytab()).thenReturn("keytab"); + + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + KerberosPrincipal result = + KerberosPrincipalFactory.create(fsOperations, + mockClientContext.getRemoteDirectoryManager(), parameters); + + assertNotNull(result); + assertEquals("file://keytab", result.getKeytab()); + assertEquals("principal", result.getPrincipalName()); + } + + @Test + public void testCreatePrincipalDistributedKeytab() throws IOException { + MockClientContext mockClientContext = new MockClientContext(); + String jobname = "testJobname"; + String keytab = "testKeytab"; + File keytabFile = createKeytabFile(keytab); + + RunJobParameters parameters = mock(RunJobParameters.class); + when(parameters.getPrincipal()).thenReturn("principal"); + when(parameters.getKeytab()).thenReturn(keytabFile.getAbsolutePath()); + when(parameters.getName()).thenReturn(jobname); + when(parameters.isDistributeKeytab()).thenReturn(true); + + FileSystemOperations fsOperations = + new FileSystemOperations(mockClientContext); + + KerberosPrincipal result = + KerberosPrincipalFactory.create(fsOperations, + mockClientContext.getRemoteDirectoryManager(), parameters); + + Path stagingDir = mockClientContext.getRemoteDirectoryManager() + .getJobStagingArea(parameters.getName(), true); + String expectedKeytabFilePath = + FileUtilitiesForTests.getFilename(stagingDir, keytab).getAbsolutePath(); + + assertNotNull(result); + assertEquals(expectedKeytabFilePath, result.getKeytab()); + assertEquals("principal", 
result.getPrincipalName()); + } + +} \ No newline at end of file diff --git a/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestSubmarineResourceUtils.java b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestSubmarineResourceUtils.java new file mode 100644 index 00000000000..f22fbaaa2b5 --- /dev/null +++ b/hadoop-submarine/hadoop-submarine-yarnservice-runtime/src/test/java/org/apache/hadoop/yarn/submarine/utils/TestSubmarineResourceUtils.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.submarine.utils; + +import com.google.common.collect.ImmutableMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.service.api.records.ResourceInformation; +import org.apache.hadoop.yarn.util.resource.CustomResourceTypesConfigurationProvider; +import org.apache.hadoop.yarn.util.resource.ResourceUtils; +import org.junit.After; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.*; + +/** + * This class is to test {@link SubmarineResourceUtils}. 
+ */ +public class TestSubmarineResourceUtils { + private static final String CUSTOM_RESOURCE_NAME = "a-custom-resource"; + + private void initResourceTypes() { + CustomResourceTypesConfigurationProvider.initResourceTypes( + ImmutableMap.builder() + .put(CUSTOM_RESOURCE_NAME, "G") + .build()); + } + + @After + public void cleanup() { + ResourceUtils.resetResourceTypes(new Configuration()); + } + + @Test + public void testConvertResourceWithCustomResource() { + initResourceTypes(); + Resource res = Resource.newInstance(4096, 12, + ImmutableMap.of(CUSTOM_RESOURCE_NAME, 20L)); + + org.apache.hadoop.yarn.service.api.records.Resource serviceResource = + SubmarineResourceUtils.convertYarnResourceToServiceResource(res); + + assertEquals(12, serviceResource.getCpus().intValue()); + assertEquals(4096, (int) Integer.valueOf(serviceResource.getMemory())); + Map additionalResources = + serviceResource.getAdditional(); + + // Additional resources also includes vcores and memory + assertEquals(3, additionalResources.size()); + ResourceInformation customResourceRI = + additionalResources.get(CUSTOM_RESOURCE_NAME); + assertEquals("G", customResourceRI.getUnit()); + assertEquals(20L, (long) customResourceRI.getValue()); + } + +} \ No newline at end of file