SUBMARINE-52. [SUBMARINE-14] Generate Service spec + launch script for single-node PyTorch learning job. Contributed by Szilard Nemeth.

This commit is contained in:
Zhankun Tang 2019-05-10 23:40:17 +08:00
parent 64c7f36ab1
commit 36267b6f7c
118 changed files with 4695 additions and 1209 deletions

View File

@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
ARG PYTHON_VERSION=3.6
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
curl \
vim \
ca-certificates \
libjpeg-dev \
libpng-dev \
wget &&\
rm -rf /var/lib/apt/lists/*
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
/opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH
RUN pip install ninja
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/pytorch.git
WORKDIR pytorch
RUN git submodule update --init
RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
pip install -v .
WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v .
WORKDIR /
# Install Hadoop
ENV HADOOP_VERSION="3.1.2"
RUN wget https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz
RUN tar zxf hadoop-${HADOOP_VERSION}.tar.gz
RUN ln -s hadoop-${HADOOP_VERSION} hadoop-current
RUN rm hadoop-${HADOOP_VERSION}.tar.gz
ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
RUN echo "$LOG_TAG Install java8" && \
apt-get update && \
apt-get install -y --no-install-recommends openjdk-8-jdk && \
apt-get clean && rm -rf /var/lib/apt/lists/*
RUN echo "Install python related packages" && \
pip --no-cache-dir install Pillow h5py ipykernel jupyter matplotlib numpy pandas scipy sklearn && \
python -m ipykernel.kernelspec
# Set the locale to fix bash warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
RUN apt-get update && apt-get install -y --no-install-recommends locales && \
apt-get clean && rm -rf /var/lib/apt/lists/*
RUN locale-gen en_US.UTF-8
WORKDIR /workspace
RUN chmod -R a+w /workspace

View File

@ -0,0 +1,30 @@
#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "Building base images"
set -e
cd base/ubuntu-16.04
docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu-base:0.0.1
echo "Finished building base images"
cd ../../with-cifar10-models/ubuntu-16.04
docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu:0.0.1

View File

@ -0,0 +1,354 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -*- coding: utf-8 -*-
"""
Training a Classifier
=====================
This is it. You have seen how to define neural networks, compute loss and make
updates to the weights of the network.
Now you might be thinking,
What about data?
----------------
Generally, when you have to deal with image, text, audio or video data,
you can use standard python packages that load data into a numpy array.
Then you can convert this array into a ``torch.*Tensor``.
- For images, packages such as Pillow, OpenCV are useful
- For audio, packages such as scipy and librosa
- For text, either raw Python or Cython based loading, or NLTK and
SpaCy are useful
Specifically for vision, we have created a package called
``torchvision``, that has data loaders for common datasets such as
Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
``torchvision.datasets`` and ``torch.utils.data.DataLoader``.
This provides a huge convenience and avoids writing boilerplate code.
For this tutorial, we will use the CIFAR10 dataset.
It has the classes: airplane, automobile, bird, cat, deer,
dog, frog, horse, ship, truck. The images in CIFAR-10 are of
size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.
.. figure:: /_static/img/cifar10.png
:alt: cifar10
cifar10
Training an image classifier
----------------------------
We will do the following steps in order:
1. Load and normalizing the CIFAR10 training and test datasets using
``torchvision``
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data
1. Loading and normalizing CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Using ``torchvision``, its extremely easy to load CIFAR10.
"""
import torch
import torchvision
import torchvision.transforms as transforms
########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
########################################################################
# Let us show some of the training images, for fun.
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
########################################################################
# 2. Define a Convolutional Neural Network
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Copy the neural network from the Neural Networks section before and modify it to
# take 3-channel images (instead of 1-channel images as it was defined).
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
########################################################################
# 3. Define a Loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's use a Classification Cross-Entropy loss and SGD with momentum.
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize.
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
########################################################################
# 5. Test the network on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We have trained the network for 2 passes over the training dataset.
# But we need to check if the network has learnt anything at all.
#
# We will check this by predicting the class label that the neural network
# outputs, and checking it against the ground-truth. If the prediction is
# correct, we add the sample to the list of correct predictions.
#
# Okay, first step. Let us display an image from the test set to get familiar.
dataiter = iter(testloader)
images, labels = dataiter.next()
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
########################################################################
# Okay, now let us see what the neural network thinks these examples above are:
outputs = net(images)
########################################################################
# The outputs are energies for the 10 classes.
# The higher the energy for a class, the more the network
# thinks that the image is of the particular class.
# So, let's get the index of the highest energy:
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
########################################################################
# The results seem pretty good.
#
# Let us look at how the network performs on the whole dataset.
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
########################################################################
# That looks waaay better than chance, which is 10% accuracy (randomly picking
# a class out of 10 classes).
# Seems like the network learnt something.
#
# Hmmm, what are the classes that performed well, and the classes that did
# not perform well:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
########################################################################
# Okay, so what next?
#
# How do we run these neural networks on the GPU?
#
# Training on GPU
# ----------------
# Just like how you transfer a Tensor onto the GPU, you transfer the neural
# net onto the GPU.
#
# Let's first define our device as the first visible cuda device if we have
# CUDA available:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
########################################################################
# The rest of this section assumes that ``device`` is a CUDA device.
#
# Then these methods will recursively go over all modules and convert their
# parameters and buffers to CUDA tensors:
#
# .. code:: python
#
# net.to(device)
#
#
# Remember that you will have to send the inputs and targets at every step
# to the GPU too:
#
# .. code:: python
#
# inputs, labels = inputs.to(device), labels.to(device)
#
# Why dont I notice MASSIVE speedup compared to CPU? Because your network
# is realllly small.
#
# **Exercise:** Try increasing the width of your network (argument 2 of
# the first ``nn.Conv2d``, and argument 1 of the second ``nn.Conv2d``
# they need to be the same number), see what kind of speedup you get.
#
# **Goals achieved**:
#
# - Understanding PyTorch's Tensor library and neural networks at a high level.
# - Train a small neural network to classify images
#
# Training on multiple GPUs
# -------------------------
# If you want to see even more MASSIVE speedup using all of your GPUs,
# please check out :doc:`data_parallel_tutorial`.
#
# Where do I go next?
# -------------------
#
# - :doc:`Train neural nets to play video games </intermediate/reinforcement_q_learning>`
# - `Train a state-of-the-art ResNet network on imagenet`_
# - `Train a face generator using Generative Adversarial Networks`_
# - `Train a word-level language model using Recurrent LSTM networks`_
# - `More examples`_
# - `More tutorials`_
# - `Discuss PyTorch on the Forums`_
# - `Chat with other users on Slack`_
#
# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
# .. _More examples: https://github.com/pytorch/examples
# .. _More tutorials: https://github.com/pytorch/tutorials
# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
# .. _Chat with other users on Slack: https://pytorch.slack.com/messages/beginner/
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
del dataiter
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%

View File

@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM pytorch-latest-gpu-base:0.0.1
RUN mkdir -p /test/data
RUN chmod -R 777 /test
ADD cifar10_tutorial.py /test/cifar10_tutorial.py

View File

@ -15,6 +15,7 @@
package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;

View File

@ -59,4 +59,6 @@ public class CliConstants {
public static final String DISTRIBUTE_KEYTAB = "distribute_keytab";
public static final String YAML_CONFIG = "f";
public static final String INSECURE_CLUSTER = "insecure";
public static final String FRAMEWORK = "framework";
}

View File

@ -16,7 +16,7 @@ package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.slf4j.Logger;

View File

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli;
/**
* Represents a Submarine command.
*/
public enum Command {
RUN_JOB, SHOW_JOB
}

View File

@ -37,7 +37,7 @@ public class ShowJobCli extends AbstractCli {
private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class);
private Options options;
private ShowJobParameters parameters = new ShowJobParameters();
private ParametersHolder parametersHolder;
public ShowJobCli(ClientContext cliContext) {
super(cliContext);
@ -62,9 +62,9 @@ public class ShowJobCli extends AbstractCli {
CommandLine cli;
try {
cli = parser.parse(options, args);
ParametersHolder parametersHolder = ParametersHolder
.createWithCmdLine(cli);
parameters.updateParameters(parametersHolder, clientContext);
parametersHolder = ParametersHolder
.createWithCmdLine(cli, Command.SHOW_JOB);
parametersHolder.updateParameters(clientContext);
} catch (ParseException e) {
printUsages();
}
@ -97,7 +97,7 @@ public class ShowJobCli extends AbstractCli {
Map<String, String> jobInfo = null;
try {
jobInfo = storage.getJobInfoByName(parameters.getName());
jobInfo = storage.getJobInfoByName(getParameters().getName());
} catch (IOException e) {
LOG.error("Failed to retrieve job info", e);
throw e;
@ -108,7 +108,7 @@ public class ShowJobCli extends AbstractCli {
@VisibleForTesting
public ShowJobParameters getParameters() {
return parameters;
return (ShowJobParameters) parametersHolder.getParameters();
}
@Override

View File

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param;
/**
* Represents the source of configuration.
*/
public enum ConfigType {
YAML, CLI
}

View File

@ -20,8 +20,12 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.Command;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles;
@ -29,15 +33,22 @@ import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Scheduling;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli.YAML_PARSE_FAILED;
/**
* This class acts as a wrapper of {@code CommandLine} values along with
* YAML configuration values.
@ -52,17 +63,110 @@ public final class ParametersHolder {
private static final Logger LOG =
LoggerFactory.getLogger(ParametersHolder.class);
public static final String SUPPORTED_FRAMEWORKS_MESSAGE =
"TensorFlow and PyTorch are the only supported frameworks for now!";
public static final String SUPPORTED_COMMANDS_MESSAGE =
"'Show job' and 'run job' are the only supported commands for now!";
private final CommandLine parsedCommandLine;
private final Map<String, String> yamlStringConfigs;
private final Map<String, List<String>> yamlListConfigs;
private final ImmutableSet onlyDefinedWithCliArgs = ImmutableSet.of(
private final ConfigType configType;
private Command command;
private final Set onlyDefinedWithCliArgs = ImmutableSet.of(
CliConstants.VERBOSE);
private final Framework framework;
private final BaseParameters parameters;
private ParametersHolder(CommandLine parsedCommandLine,
YamlConfigFile yamlConfig) {
YamlConfigFile yamlConfig, ConfigType configType, Command command)
throws ParseException, YarnException {
this.parsedCommandLine = parsedCommandLine;
this.yamlStringConfigs = initStringConfigValues(yamlConfig);
this.yamlListConfigs = initListConfigValues(yamlConfig);
this.configType = configType;
this.command = command;
this.framework = determineFrameworkType();
this.ensureOnlyValidSectionsAreDefined(yamlConfig);
this.parameters = createParameters();
}
private BaseParameters createParameters() {
if (command == Command.RUN_JOB) {
if (framework == Framework.TENSORFLOW) {
return new TensorFlowRunJobParameters();
} else if (framework == Framework.PYTORCH) {
return new PyTorchRunJobParameters();
} else {
throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
}
} else if (command == Command.SHOW_JOB) {
return new ShowJobParameters();
} else {
throw new UnsupportedOperationException(SUPPORTED_COMMANDS_MESSAGE);
}
}
private void ensureOnlyValidSectionsAreDefined(YamlConfigFile yamlConfig) {
if (isCommandRunJob() && isFrameworkPyTorch() &&
isPsSectionDefined(yamlConfig)) {
throw new YamlParseException(
"PS section should not be defined when PyTorch " +
"is the selected framework!");
}
if (isCommandRunJob() && isFrameworkPyTorch() &&
isTensorboardSectionDefined(yamlConfig)) {
throw new YamlParseException(
"TensorBoard section should not be defined when PyTorch " +
"is the selected framework!");
}
}
private boolean isCommandRunJob() {
return command == Command.RUN_JOB;
}
private boolean isFrameworkPyTorch() {
return framework == Framework.PYTORCH;
}
private boolean isPsSectionDefined(YamlConfigFile yamlConfig) {
return yamlConfig != null &&
yamlConfig.getRoles() != null &&
yamlConfig.getRoles().getPs() != null;
}
private boolean isTensorboardSectionDefined(YamlConfigFile yamlConfig) {
return yamlConfig != null &&
yamlConfig.getTensorBoard() != null;
}
private Framework determineFrameworkType()
throws ParseException, YarnException {
if (!isCommandRunJob()) {
return null;
}
String frameworkStr = getOptionValue(CliConstants.FRAMEWORK);
if (frameworkStr == null) {
LOG.info("Framework is not defined in config, falling back to " +
"TensorFlow as a default.");
return Framework.TENSORFLOW;
}
Framework framework = Framework.parseByValue(frameworkStr);
if (framework == null) {
if (getConfigType() == ConfigType.CLI) {
throw new ParseException("Failed to parse Framework type! "
+ "Valid values are: " + Framework.getValues());
} else {
throw new YamlParseException(YAML_PARSE_FAILED +
", framework should is defined, but it has an invalid value! " +
"Valid values are: " + Framework.getValues());
}
}
return framework;
}
/**
@ -108,6 +212,8 @@ public final class ParametersHolder {
private void initGenericConfigs(YamlConfigFile yamlConfig,
Map<String, String> yamlConfigs) {
yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName());
yamlConfigs.put(CliConstants.FRAMEWORK,
yamlConfig.getSpec().getFramework());
Configs configs = yamlConfig.getConfigs();
yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath());
@ -178,13 +284,15 @@ public final class ParametersHolder {
.collect(Collectors.toList());
}
public static ParametersHolder createWithCmdLine(CommandLine cli) {
return new ParametersHolder(cli, null);
public static ParametersHolder createWithCmdLine(CommandLine cli,
Command command) throws ParseException, YarnException {
return new ParametersHolder(cli, null, ConfigType.CLI, command);
}
public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli,
YamlConfigFile yamlConfig) {
return new ParametersHolder(cli, yamlConfig);
YamlConfigFile yamlConfig, Command command) throws ParseException,
YarnException {
return new ParametersHolder(cli, yamlConfig, ConfigType.YAML, command);
}
/**
@ -193,7 +301,7 @@ public final class ParametersHolder {
* @param option Name of the config.
* @return The value of the config
*/
String getOptionValue(String option) throws YarnException {
public String getOptionValue(String option) throws YarnException {
ensureConfigIsDefinedOnce(option, true);
if (onlyDefinedWithCliArgs.contains(option) ||
parsedCommandLine.hasOption(option)) {
@ -208,7 +316,7 @@ public final class ParametersHolder {
* @param option Name of the config.
* @return The values of the config
*/
List<String> getOptionValues(String option) throws YarnException {
public List<String> getOptionValues(String option) throws YarnException {
ensureConfigIsDefinedOnce(option, false);
if (onlyDefinedWithCliArgs.contains(option) ||
parsedCommandLine.hasOption(option)) {
@ -285,7 +393,7 @@ public final class ParametersHolder {
* @return true, if the option is found in the CLI args or in the YAML config,
* false otherwise.
*/
boolean hasOption(String option) {
public boolean hasOption(String option) {
if (onlyDefinedWithCliArgs.contains(option)) {
boolean value = parsedCommandLine.hasOption(option);
if (LOG.isDebugEnabled()) {
@ -312,4 +420,21 @@ public final class ParametersHolder {
"from YAML configuration.", result, option);
return result;
}
public ConfigType getConfigType() {
return configType;
}
public Framework getFramework() {
return framework;
}
public void updateParameters(ClientContext clientContext)
throws ParseException, YarnException, IOException {
parameters.updateParameters(this, clientContext);
}
public BaseParameters getParameters() {
return parameters;
}
}

View File

@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
import java.io.IOException;
import java.util.List;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import com.google.common.collect.Lists;
/**
* Parameters for PyTorch job.
*/
public class PyTorchRunJobParameters extends RunJobParameters {
private static final String CANNOT_BE_DEFINED_FOR_PYTORCH =
"cannot be defined for PyTorch jobs!";
@Override
public void updateParameters(ParametersHolder parametersHolder,
ClientContext clientContext)
throws ParseException, IOException, YarnException {
checkArguments(parametersHolder);
super.updateParameters(parametersHolder, clientContext);
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
this.workerParameters =
getWorkerParameters(clientContext, parametersHolder, input);
this.distributed = determineIfDistributed(workerParameters.getReplicas());
executePostOperations(clientContext);
}
private void checkArguments(ParametersHolder parametersHolder)
throws YarnException, ParseException {
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.N_PS));
} else if (parametersHolder.getOptionValue(CliConstants.PS_RES) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.PS_RES));
} else if (parametersHolder
.getOptionValue(CliConstants.PS_DOCKER_IMAGE) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.PS_DOCKER_IMAGE));
} else if (parametersHolder
.getOptionValue(CliConstants.PS_LAUNCH_CMD) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.PS_LAUNCH_CMD));
} else if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.TENSORBOARD));
} else if (parametersHolder
.getOptionValue(CliConstants.TENSORBOARD_RESOURCES) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.TENSORBOARD_RESOURCES));
} else if (parametersHolder
.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.TENSORBOARD_DOCKER_IMAGE));
}
}
private String getParamCannotBeDefinedErrorMessage(String cliName) {
return String.format(
"Parameter '%s' " + CANNOT_BE_DEFINED_FOR_PYTORCH, cliName);
}
@Override
void executePostOperations(ClientContext clientContext) throws IOException {
// Set default job dir / saved model dir, etc.
setDefaultDirs(clientContext);
replacePatternsInParameters(clientContext);
}
private void replacePatternsInParameters(ClientContext clientContext)
throws IOException {
if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
String afterReplace =
CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
clientContext.getRemoteDirectoryManager());
setWorkerLaunchCmd(afterReplace);
}
}
@Override
public List<String> getLaunchCommands() {
return Lists.newArrayList(getWorkerLaunchCmd());
}
/**
* We only support non-distributed PyTorch integration for now.
* @param nWorkers
* @return
*/
private boolean determineIfDistributed(int nWorkers) {
return false;
}
}

View File

@ -1,3 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -12,7 +28,7 @@
* limitations under the License. See accompanying LICENSE file.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param;
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.CaseFormat;
@ -21,7 +37,14 @@ import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.yaml.snakeyaml.introspector.Property;
import org.yaml.snakeyaml.introspector.PropertyUtils;
@ -34,27 +57,15 @@ import java.util.List;
/**
* Parameters used to run a job
*/
public class RunJobParameters extends RunParameters {
public abstract class RunJobParameters extends RunParameters {
private String input;
private String checkpointPath;
private int numWorkers;
private int numPS;
private Resource workerResource;
private Resource psResource;
private boolean tensorboardEnabled;
private Resource tensorboardResource;
private String tensorboardDockerImage;
private String workerLaunchCmd;
private String psLaunchCmd;
private List<Quicklink> quicklinks = new ArrayList<>();
private List<Localization> localizations = new ArrayList<>();
private String psDockerImage = null;
private String workerDockerImage = null;
private boolean waitJobFinish = false;
private boolean distributed = false;
protected boolean distributed = false;
private boolean securityDisabled = false;
private String keytab;
@ -62,6 +73,9 @@ public class RunJobParameters extends RunParameters {
private boolean distributeKeytab = false;
private List<String> confPairs = new ArrayList<>();
RoleParameters workerParameters =
RoleParameters.createEmpty(TensorFlowRole.WORKER);
@Override
public void updateParameters(ParametersHolder parametersHolder,
ClientContext clientContext)
@ -70,34 +84,6 @@ public class RunJobParameters extends RunParameters {
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
String jobDir = parametersHolder.getOptionValue(
CliConstants.CHECKPOINT_PATH);
int nWorkers = 1;
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
nWorkers = Integer.parseInt(
parametersHolder.getOptionValue(CliConstants.N_WORKERS));
// Only check null value.
// Training job shouldn't ignore INPUT_PATH option
// But if nWorkers is 0, INPUT_PATH can be ignored because
// user can only run Tensorboard
if (null == input && 0 != nWorkers) {
throw new ParseException("\"--" + CliConstants.INPUT_PATH +
"\" is absent");
}
}
int nPS = 0;
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
nPS = Integer.parseInt(
parametersHolder.getOptionValue(CliConstants.N_PS));
}
// Check #workers and #ps.
// When distributed training is required
if (nWorkers >= 2 && nPS > 0) {
distributed = true;
} else if (nWorkers <= 1 && nPS > 0) {
throw new ParseException("Only specified one worker but non-zero PS, "
+ "please double check.");
}
if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) {
setSecurityDisabled(true);
@ -109,46 +95,6 @@ public class RunJobParameters extends RunParameters {
CliConstants.PRINCIPAL);
CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal);
workerResource = null;
if (nWorkers > 0) {
String workerResourceStr = parametersHolder.getOptionValue(
CliConstants.WORKER_RES);
if (workerResourceStr == null) {
throw new ParseException(
"--" + CliConstants.WORKER_RES + " is absent.");
}
workerResource = ResourceUtils.createResourceFromString(
workerResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
Resource psResource = null;
if (nPS > 0) {
String psResourceStr = parametersHolder.getOptionValue(
CliConstants.PS_RES);
if (psResourceStr == null) {
throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
}
psResource = ResourceUtils.createResourceFromString(psResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
boolean tensorboard = false;
if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
tensorboard = true;
String tensorboardResourceStr = parametersHolder.getOptionValue(
CliConstants.TENSORBOARD_RESOURCES);
if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
}
tensorboardResource = ResourceUtils.createResourceFromString(
tensorboardResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
tensorboardDockerImage = parametersHolder.getOptionValue(
CliConstants.TENSORBOARD_DOCKER_IMAGE);
this.setTensorboardResource(tensorboardResource);
}
if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) {
this.waitJobFinish = true;
}
@ -164,16 +110,6 @@ public class RunJobParameters extends RunParameters {
}
}
psDockerImage = parametersHolder.getOptionValue(
CliConstants.PS_DOCKER_IMAGE);
workerDockerImage = parametersHolder.getOptionValue(
CliConstants.WORKER_DOCKER_IMAGE);
String workerLaunchCmd = parametersHolder.getOptionValue(
CliConstants.WORKER_LAUNCH_CMD);
String psLaunchCommand = parametersHolder.getOptionValue(
CliConstants.PS_LAUNCH_CMD);
// Localizations
List<String> localizationsStr = parametersHolder.getOptionValues(
CliConstants.LOCALIZATION);
@ -191,10 +127,6 @@ public class RunJobParameters extends RunParameters {
.getOptionValues(CliConstants.ARG_CONF);
this.setInputPath(input).setCheckpointPath(jobDir)
.setNumPS(nPS).setNumWorkers(nWorkers)
.setPSLaunchCmd(psLaunchCommand).setWorkerLaunchCmd(workerLaunchCmd)
.setPsResource(psResource)
.setTensorboardEnabled(tensorboard)
.setKeytab(kerberosKeytab)
.setPrincipal(kerberosPrincipal)
.setDistributeKeytab(distributeKerberosKeytab)
@ -203,6 +135,39 @@ public class RunJobParameters extends RunParameters {
super.updateParameters(parametersHolder, clientContext);
}
abstract void executePostOperations(ClientContext clientContext)
throws IOException;
void setDefaultDirs(ClientContext clientContext) throws IOException {
// Create directories if needed
String jobDir = getCheckpointPath();
if (jobDir == null) {
jobDir = getJobDir(clientContext);
setCheckpointPath(jobDir);
}
if (getNumWorkers() > 0) {
String savedModelDir = getSavedModelPath();
if (savedModelDir == null) {
savedModelDir = jobDir;
setSavedModelPath(savedModelDir);
}
}
}
private String getJobDir(ClientContext clientContext) throws IOException {
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
if (getNumWorkers() > 0) {
return rdm.getJobCheckpointDir(getName(), true).toString();
} else {
// when #workers == 0, it means we only launch TB. In that case,
// point job dir to root dir so all job's metrics will be shown.
return rdm.getUserRootFolder().toString();
}
}
public abstract List<String> getLaunchCommands();
public String getInputPath() {
return input;
}
@ -221,110 +186,10 @@ public class RunJobParameters extends RunParameters {
return this;
}
public int getNumWorkers() {
return numWorkers;
}
public RunJobParameters setNumWorkers(int numWorkers) {
this.numWorkers = numWorkers;
return this;
}
public int getNumPS() {
return numPS;
}
public RunJobParameters setNumPS(int numPS) {
this.numPS = numPS;
return this;
}
public Resource getWorkerResource() {
return workerResource;
}
public RunJobParameters setWorkerResource(Resource workerResource) {
this.workerResource = workerResource;
return this;
}
public Resource getPsResource() {
return psResource;
}
public RunJobParameters setPsResource(Resource psResource) {
this.psResource = psResource;
return this;
}
public boolean isTensorboardEnabled() {
return tensorboardEnabled;
}
public RunJobParameters setTensorboardEnabled(boolean tensorboardEnabled) {
this.tensorboardEnabled = tensorboardEnabled;
return this;
}
public String getWorkerLaunchCmd() {
return workerLaunchCmd;
}
public RunJobParameters setWorkerLaunchCmd(String workerLaunchCmd) {
this.workerLaunchCmd = workerLaunchCmd;
return this;
}
public String getPSLaunchCmd() {
return psLaunchCmd;
}
public RunJobParameters setPSLaunchCmd(String psLaunchCmd) {
this.psLaunchCmd = psLaunchCmd;
return this;
}
public boolean isWaitJobFinish() {
return waitJobFinish;
}
public String getPsDockerImage() {
return psDockerImage;
}
public void setPsDockerImage(String psDockerImage) {
this.psDockerImage = psDockerImage;
}
public String getWorkerDockerImage() {
return workerDockerImage;
}
public void setWorkerDockerImage(String workerDockerImage) {
this.workerDockerImage = workerDockerImage;
}
public boolean isDistributed() {
return distributed;
}
public Resource getTensorboardResource() {
return tensorboardResource;
}
public void setTensorboardResource(Resource tensorboardResource) {
this.tensorboardResource = tensorboardResource;
}
public String getTensorboardDockerImage() {
return tensorboardDockerImage;
}
public void setTensorboardDockerImage(String tensorboardDockerImage) {
this.tensorboardDockerImage = tensorboardDockerImage;
}
public List<Quicklink> getQuicklinks() {
return quicklinks;
}
@ -382,6 +247,90 @@ public class RunJobParameters extends RunParameters {
this.distributed = distributed;
}
RoleParameters getWorkerParameters(ClientContext clientContext,
ParametersHolder parametersHolder, String input)
throws ParseException, YarnException, IOException {
int nWorkers = getNumberOfWorkers(parametersHolder, input);
Resource workerResource =
determineWorkerResource(parametersHolder, nWorkers, clientContext);
String workerDockerImage =
parametersHolder.getOptionValue(CliConstants.WORKER_DOCKER_IMAGE);
String workerLaunchCmd =
parametersHolder.getOptionValue(CliConstants.WORKER_LAUNCH_CMD);
return new RoleParameters(TensorFlowRole.WORKER, nWorkers,
workerLaunchCmd, workerDockerImage, workerResource);
}
private Resource determineWorkerResource(ParametersHolder parametersHolder,
int nWorkers, ClientContext clientContext)
throws ParseException, YarnException, IOException {
if (nWorkers > 0) {
String workerResourceStr =
parametersHolder.getOptionValue(CliConstants.WORKER_RES);
if (workerResourceStr == null) {
throw new ParseException(
"--" + CliConstants.WORKER_RES + " is absent.");
}
return ResourceUtils.createResourceFromString(workerResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
return null;
}
private int getNumberOfWorkers(ParametersHolder parametersHolder,
String input) throws ParseException, YarnException {
int nWorkers = 1;
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
nWorkers = Integer
.parseInt(parametersHolder.getOptionValue(CliConstants.N_WORKERS));
// Only check null value.
// Training job shouldn't ignore INPUT_PATH option
// But if nWorkers is 0, INPUT_PATH can be ignored because
// user can only run Tensorboard
if (null == input && 0 != nWorkers) {
throw new ParseException(
"\"--" + CliConstants.INPUT_PATH + "\" is absent");
}
}
return nWorkers;
}
public String getWorkerLaunchCmd() {
return workerParameters.getLaunchCommand();
}
public void setWorkerLaunchCmd(String launchCmd) {
workerParameters.setLaunchCommand(launchCmd);
}
public int getNumWorkers() {
return workerParameters.getReplicas();
}
public void setNumWorkers(int numWorkers) {
workerParameters.setReplicas(numWorkers);
}
public Resource getWorkerResource() {
return workerParameters.getResource();
}
public void setWorkerResource(Resource resource) {
workerParameters.setResource(resource);
}
public String getWorkerDockerImage() {
return workerParameters.getDockerImage();
}
public void setWorkerDockerImage(String image) {
workerParameters.setDockerImage(image);
}
public boolean isDistributed() {
return distributed;
}
@VisibleForTesting
public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
@Override

View File

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
import com.google.common.collect.Lists;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import java.io.IOException;
import java.util.List;
/**
* Parameters for TensorFlow job.
*/
public class TensorFlowRunJobParameters extends RunJobParameters {
private boolean tensorboardEnabled;
private RoleParameters psParameters =
RoleParameters.createEmpty(TensorFlowRole.PS);
private RoleParameters tensorBoardParameters =
RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);
@Override
public void updateParameters(ParametersHolder parametersHolder,
ClientContext clientContext)
throws ParseException, IOException, YarnException {
super.updateParameters(parametersHolder, clientContext);
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
this.workerParameters =
getWorkerParameters(clientContext, parametersHolder, input);
this.psParameters = getPSParameters(clientContext, parametersHolder);
this.distributed = determineIfDistributed(workerParameters.getReplicas(),
psParameters.getReplicas());
if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
this.tensorboardEnabled = true;
this.tensorBoardParameters =
getTensorBoardParameters(parametersHolder, clientContext);
}
executePostOperations(clientContext);
}
@Override
void executePostOperations(ClientContext clientContext) throws IOException {
// Set default job dir / saved model dir, etc.
setDefaultDirs(clientContext);
replacePatternsInParameters(clientContext);
}
private void replacePatternsInParameters(ClientContext clientContext)
throws IOException {
if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
setPSLaunchCmd(afterReplace);
}
if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
String afterReplace =
CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
clientContext.getRemoteDirectoryManager());
setWorkerLaunchCmd(afterReplace);
}
}
@Override
public List<String> getLaunchCommands() {
return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
}
private boolean determineIfDistributed(int nWorkers, int nPS)
throws ParseException {
// Check #workers and #ps.
// When distributed training is required
if (nWorkers >= 2 && nPS > 0) {
return true;
} else if (nWorkers <= 1 && nPS > 0) {
throw new ParseException("Only specified one worker but non-zero PS, "
+ "please double check.");
}
return false;
}
private RoleParameters getPSParameters(ClientContext clientContext,
ParametersHolder parametersHolder)
throws YarnException, IOException, ParseException {
int nPS = getNumberOfPS(parametersHolder);
Resource psResource =
determinePSResource(parametersHolder, nPS, clientContext);
String psDockerImage =
parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
String psLaunchCommand =
parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
psDockerImage, psResource);
}
private Resource determinePSResource(ParametersHolder parametersHolder,
int nPS, ClientContext clientContext)
throws ParseException, YarnException, IOException {
if (nPS > 0) {
String psResourceStr =
parametersHolder.getOptionValue(CliConstants.PS_RES);
if (psResourceStr == null) {
throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
}
return ResourceUtils.createResourceFromString(psResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
return null;
}
private int getNumberOfPS(ParametersHolder parametersHolder)
throws YarnException {
int nPS = 0;
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
nPS =
Integer.parseInt(parametersHolder.getOptionValue(CliConstants.N_PS));
}
return nPS;
}
private RoleParameters getTensorBoardParameters(
ParametersHolder parametersHolder, ClientContext clientContext)
throws YarnException, IOException {
String tensorboardResourceStr =
parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
}
Resource tensorboardResource =
ResourceUtils.createResourceFromString(tensorboardResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
String tensorboardDockerImage =
parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
tensorboardDockerImage, tensorboardResource);
}
public int getNumPS() {
return psParameters.getReplicas();
}
public void setNumPS(int numPS) {
psParameters.setReplicas(numPS);
}
public Resource getPsResource() {
return psParameters.getResource();
}
public void setPsResource(Resource resource) {
psParameters.setResource(resource);
}
public String getPsDockerImage() {
return psParameters.getDockerImage();
}
public void setPsDockerImage(String image) {
psParameters.setDockerImage(image);
}
public String getPSLaunchCmd() {
return psParameters.getLaunchCommand();
}
public void setPSLaunchCmd(String launchCmd) {
psParameters.setLaunchCommand(launchCmd);
}
public boolean isTensorboardEnabled() {
return tensorboardEnabled;
}
public Resource getTensorboardResource() {
return tensorBoardParameters.getResource();
}
public void setTensorboardResource(Resource resource) {
tensorBoardParameters.setResource(resource);
}
public String getTensorboardDockerImage() {
return tensorBoardParameters.getDockerImage();
}
public void setTensorboardDockerImage(String image) {
tensorBoardParameters.setDockerImage(image);
}
}

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes that hold run job parameters for
* TensorFlow / PyTorch jobs.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;

View File

@ -22,6 +22,7 @@ package org.apache.hadoop.yarn.submarine.client.cli.param.yaml;
public class Spec {
private String name;
private String jobType;
private String framework;
public String getJobType() {
return jobType;
@ -38,4 +39,12 @@ public class Spec {
public void setName(String name) {
this.name = name;
}
public String getFramework() {
return framework;
}
public void setFramework(String framework) {
this.framework = framework;
}
}

View File

@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;
/**
* Represents the type of Machine learning framework to work with.
*/
public enum Framework {
TENSORFLOW(Constants.TENSORFLOW_NAME), PYTORCH(Constants.PYTORCH_NAME);
private String value;
Framework(String value) {
this.value = value;
}
public String getValue() {
return value;
}
public static Framework parseByValue(String value) {
for (Framework fw : Framework.values()) {
if (fw.value.equalsIgnoreCase(value)) {
return fw;
}
}
return null;
}
public static String getValues() {
List<String> values = Lists.newArrayList(Framework.values()).stream()
.map(fw -> fw.value).collect(Collectors.toList());
return String.join(",", values);
}
private static class Constants {
static final String TENSORFLOW_NAME = "tensorflow";
static final String PYTORCH_NAME = "pytorch";
}
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.submarine.common.api.Role;
/**
* This class encapsulates data related to a particular Role.
* Some examples: TF Worker process, TF PS process or a PyTorch worker process.
*/
public class RoleParameters {
private final Role role;
private int replicas;
private String launchCommand;
private String dockerImage;
private Resource resource;
public RoleParameters(Role role, int replicas,
String launchCommand, String dockerImage, Resource resource) {
this.role = role;
this.replicas = replicas;
this.launchCommand = launchCommand;
this.dockerImage = dockerImage;
this.resource = resource;
}
public static RoleParameters createEmpty(Role role) {
return new RoleParameters(role, 0, null, null, null);
}
public Role getRole() {
return role;
}
public int getReplicas() {
return replicas;
}
public String getLaunchCommand() {
return launchCommand;
}
public void setLaunchCommand(String launchCommand) {
this.launchCommand = launchCommand;
}
public String getDockerImage() {
return dockerImage;
}
public void setDockerImage(String dockerImage) {
this.dockerImage = dockerImage;
}
public Resource getResource() {
return resource;
}
public void setResource(Resource resource) {
this.resource = resource;
}
public void setReplicas(int replicas) {
this.replicas = replicas;
}
}

View File

@ -1,3 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -12,7 +28,7 @@
* limitations under the License. See accompanying LICENSE file.
*/
package org.apache.hadoop.yarn.submarine.client.cli;
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.cli.CommandLine;
@ -23,9 +39,13 @@ import org.apache.commons.cli.ParseException;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.AbstractCli;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.Command;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
@ -44,17 +64,25 @@ import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
* This purpose of this class is to handle / parse CLI arguments related to
* the run job Submarine command.
*/
public class RunJobCli extends AbstractCli {
private static final Logger LOG =
LoggerFactory.getLogger(RunJobCli.class);
private static final String YAML_PARSE_FAILED = "Failed to parse " +
private static final String CAN_BE_USED_WITH_TF_PYTORCH =
"Can be used with TensorFlow or PyTorch frameworks.";
private static final String CAN_BE_USED_WITH_TF_ONLY =
"Can only be used with TensorFlow framework.";
public static final String YAML_PARSE_FAILED = "Failed to parse " +
"YAML config";
private Options options;
private RunJobParameters parameters = new RunJobParameters();
private Options options;
private JobSubmitter jobSubmitter;
private JobMonitor jobMonitor;
private ParametersHolder parametersHolder;
public RunJobCli(ClientContext cliContext) {
this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
@ -62,7 +90,7 @@ public class RunJobCli extends AbstractCli {
}
@VisibleForTesting
RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
public RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
JobMonitor jobMonitor) {
super(cliContext);
this.options = generateOptions();
@ -78,6 +106,10 @@ public class RunJobCli extends AbstractCli {
Options options = new Options();
options.addOption(CliConstants.YAML_CONFIG, true,
"Config file (in YAML format)");
options.addOption(CliConstants.FRAMEWORK, true,
String.format("Framework to use. Valid values are: %s! " +
"The default framework is Tensorflow.",
Framework.getValues()));
options.addOption(CliConstants.NAME, true, "Name of the job");
options.addOption(CliConstants.INPUT_PATH, true,
"Input of the job, could be local or other FS directory");
@ -88,48 +120,22 @@ public class RunJobCli extends AbstractCli {
options.addOption(CliConstants.SAVED_MODEL_PATH, true,
"Model exported path (savedmodel) of the job, which is needed when "
+ "exported model is not placed under ${checkpoint_path}"
+ "could be local or other FS directory. This will be used to serve.");
options.addOption(CliConstants.N_WORKERS, true,
"Number of worker tasks of the job, by default it's 1");
options.addOption(CliConstants.N_PS, true,
"Number of PS tasks of the job, by default it's 0");
options.addOption(CliConstants.WORKER_RES, true,
"Resource of each worker, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2");
options.addOption(CliConstants.PS_RES, true,
"Resource of each PS, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2");
+ "could be local or other FS directory. " +
"This will be used to serve.");
options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag");
options.addOption(CliConstants.QUEUE, true,
"Name of queue to run the job, by default it uses default queue");
options.addOption(CliConstants.TENSORBOARD, false,
"Should we run TensorBoard"
+ " for this job? By default it's disabled");
options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
"Specify resources of Tensorboard, by default it is "
+ CliConstants.TENSORBOARD_DEFAULT_RESOURCES);
options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
"Specify Tensorboard docker image. when this is not "
+ "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+ " as default.");
options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the worker");
options.addOption(CliConstants.PS_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the PS");
addWorkerOptions(options);
addPSOptions(options);
addTensorboardOptions(options);
options.addOption(CliConstants.ENV, true,
"Common environment variable of worker/ps");
options.addOption(CliConstants.VERBOSE, false,
"Print verbose log for troubleshooting");
options.addOption(CliConstants.WAIT_JOB_FINISH, false,
"Specified when user want to wait the job finish");
options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
"Specify docker image for PS, when this is not specified, PS uses --"
+ CliConstants.DOCKER_IMAGE + " as default.");
options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
"Specify docker image for WORKER, when this is not specified, WORKER "
+ "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
+ "web UI shows link to given role instance and port. When "
+ "--tensorboard is specified, quicklink to tensorboard instance will "
@ -172,63 +178,97 @@ public class RunJobCli extends AbstractCli {
return options;
}
private void replacePatternsInParameters() throws IOException {
if (parameters.getPSLaunchCmd() != null && !parameters.getPSLaunchCmd()
.isEmpty()) {
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
parameters.getPSLaunchCmd(), parameters,
clientContext.getRemoteDirectoryManager());
parameters.setPSLaunchCmd(afterReplace);
}
private void addWorkerOptions(Options options) {
options.addOption(CliConstants.N_WORKERS, true,
"Number of worker tasks of the job, by default it's 1." +
CAN_BE_USED_WITH_TF_PYTORCH);
options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
"Specify docker image for WORKER, when this is not specified, WORKER "
+ "uses --" + CliConstants.DOCKER_IMAGE + " as default." +
CAN_BE_USED_WITH_TF_PYTORCH);
options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the worker" +
CAN_BE_USED_WITH_TF_PYTORCH);
options.addOption(CliConstants.WORKER_RES, true,
"Resource of each worker, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
CAN_BE_USED_WITH_TF_PYTORCH);
}
if (parameters.getWorkerLaunchCmd() != null && !parameters
.getWorkerLaunchCmd().isEmpty()) {
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
parameters.getWorkerLaunchCmd(), parameters,
clientContext.getRemoteDirectoryManager());
parameters.setWorkerLaunchCmd(afterReplace);
}
private void addPSOptions(Options options) {
options.addOption(CliConstants.N_PS, true,
"Number of PS tasks of the job, by default it's 0. " +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
"Specify docker image for PS, when this is not specified, PS uses --"
+ CliConstants.DOCKER_IMAGE + " as default." +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.PS_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the PS" +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.PS_RES, true,
"Resource of each PS, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
CAN_BE_USED_WITH_TF_ONLY);
}
private void addTensorboardOptions(Options options) {
options.addOption(CliConstants.TENSORBOARD, false,
"Should we run TensorBoard"
+ " for this job? By default it's disabled." +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
"Specify resources of Tensorboard, by default it is "
+ CliConstants.TENSORBOARD_DEFAULT_RESOURCES + "." +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
"Specify Tensorboard docker image. when this is not "
+ "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+ " as default." +
CAN_BE_USED_WITH_TF_ONLY);
}
private void parseCommandLineAndGetRunJobParameters(String[] args)
throws ParseException, IOException, YarnException {
try {
// Do parsing
GnuParser parser = new GnuParser();
CommandLine cli = parser.parse(options, args);
ParametersHolder parametersHolder = createParametersHolder(cli);
parameters.updateParameters(parametersHolder, clientContext);
parametersHolder = createParametersHolder(cli);
parametersHolder.updateParameters(clientContext);
} catch (ParseException e) {
LOG.error("Exception in parse: {}", e.getMessage());
printUsages();
throw e;
}
// Set default job dir / saved model dir, etc.
setDefaultDirs();
// replace patterns
replacePatternsInParameters();
}
private ParametersHolder createParametersHolder(CommandLine cli) {
private ParametersHolder createParametersHolder(CommandLine cli)
throws ParseException, YarnException {
String yamlConfigFile =
cli.getOptionValue(CliConstants.YAML_CONFIG);
if (yamlConfigFile != null) {
YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile);
if (yamlConfig == null) {
throw new YamlParseException(String.format(
YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
} else if (yamlConfig.getConfigs() == null) {
throw new YamlParseException(String.format(YAML_PARSE_FAILED +
", config section should be defined, but it cannot be found in " +
"YAML file '%s'!", yamlConfigFile));
}
checkYamlConfig(yamlConfigFile, yamlConfig);
LOG.info("Using YAML configuration!");
return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig);
return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig,
Command.RUN_JOB);
} else {
LOG.info("Using CLI configuration!");
return ParametersHolder.createWithCmdLine(cli);
return ParametersHolder.createWithCmdLine(cli, Command.RUN_JOB);
}
}
private void checkYamlConfig(String yamlConfigFile,
YamlConfigFile yamlConfig) {
if (yamlConfig == null) {
throw new YamlParseException(String.format(
YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
} else if (yamlConfig.getConfigs() == null) {
throw new YamlParseException(String.format(YAML_PARSE_FAILED +
", config section should be defined, but it cannot be found in " +
"YAML file '%s'!", yamlConfigFile));
}
}
@ -256,34 +296,9 @@ public class RunJobCli extends AbstractCli {
e);
}
private void setDefaultDirs() throws IOException {
// Create directories if needed
String jobDir = parameters.getCheckpointPath();
if (null == jobDir) {
if (parameters.getNumWorkers() > 0) {
jobDir = clientContext.getRemoteDirectoryManager().getJobCheckpointDir(
parameters.getName(), true).toString();
} else {
// when #workers == 0, it means we only launch TB. In that case,
// point job dir to root dir so all job's metrics will be shown.
jobDir = clientContext.getRemoteDirectoryManager().getUserRootFolder()
.toString();
}
parameters.setCheckpointPath(jobDir);
}
if (parameters.getNumWorkers() > 0) {
// Only do this when #worker > 0
String savedModelDir = parameters.getSavedModelPath();
if (null == savedModelDir) {
savedModelDir = jobDir;
parameters.setSavedModelPath(savedModelDir);
}
}
}
private void storeJobInformation(String jobName, ApplicationId applicationId,
String[] args) throws IOException {
private void storeJobInformation(RunJobParameters parameters,
ApplicationId applicationId, String[] args) throws IOException {
String jobName = parameters.getName();
Map<String, String> jobInfo = new HashMap<>();
jobInfo.put(StorageKeyConstants.JOB_NAME, jobName);
jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString());
@ -316,8 +331,10 @@ public class RunJobCli extends AbstractCli {
}
parseCommandLineAndGetRunJobParameters(args);
ApplicationId applicationId = this.jobSubmitter.submitJob(parameters);
storeJobInformation(parameters.getName(), applicationId, args);
ApplicationId applicationId = jobSubmitter.submitJob(parametersHolder);
RunJobParameters parameters =
(RunJobParameters) parametersHolder.getParameters();
storeJobInformation(parameters, applicationId, args);
if (parameters.isWaitJobFinish()) {
this.jobMonitor.waitTrainingFinal(parameters.getName());
}
@ -332,6 +349,6 @@ public class RunJobCli extends AbstractCli {
@VisibleForTesting
public RunJobParameters getRunJobParameters() {
return parameters;
return (RunJobParameters) parametersHolder.getParameters();
}
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes that are related to the run job command.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;

View File

@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. See accompanying LICENSE file.
*/
package org.apache.hadoop.yarn.submarine.common.api;
/**
* Enum to represent a PyTorch Role.
*/
public enum PyTorchRole implements Role {
PRIMARY_WORKER("master"),
WORKER("worker");
private String compName;
PyTorchRole(String compName) {
this.compName = compName;
}
public String getComponentName() {
return compName;
}
@Override
public String getName() {
return name();
}
}

View File

@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.common.api;
/**
* Interface for a Role.
*/
public interface Role {
String getComponentName();
String getName();
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.common.api;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;
/**
* Represents the type of Runtime.
*/
public enum Runtime {
TONY(Constants.TONY), YARN_SERVICE(Constants.YARN_SERVICE);
private String value;
Runtime(String value) {
this.value = value;
}
public String getValue() {
return value;
}
public static Runtime parseByValue(String value) {
for (Runtime rt : Runtime.values()) {
if (rt.value.equalsIgnoreCase(value)) {
return rt;
}
}
return null;
}
public static String getValues() {
List<String> values = Lists.newArrayList(Runtime.values()).stream()
.map(rt -> rt.value).collect(Collectors.toList());
return String.join(",", values);
}
public static class Constants {
public static final String TONY = "tony";
public static final String YARN_SERVICE = "yarnservice";
}
}

View File

@ -14,7 +14,10 @@
package org.apache.hadoop.yarn.submarine.common.api;
public enum TaskType {
/**
* Enum to represent a TensorFlow Role.
*/
public enum TensorFlowRole implements Role {
PRIMARY_WORKER("master"),
WORKER("worker"),
PS("ps"),
@ -22,11 +25,17 @@ public enum TaskType {
private String compName;
TaskType(String compName) {
TensorFlowRole(String compName) {
this.compName = compName;
}
@Override
public String getComponentName() {
return compName;
}
@Override
public String getName() {
return name();
}
}

View File

@ -16,21 +16,21 @@ package org.apache.hadoop.yarn.submarine.runtimes.common;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import java.io.IOException;
/**
* Submit job to cluster master
* Submit job to cluster master.
*/
public interface JobSubmitter {
/**
* Submit job to cluster
* Submit a job to cluster.
* @param parameters run job parameters
* @return applicatioId when successfully submitted
* @return applicationId when successfully submitted
* @throws YarnException for issues while contacting YARN daemons
* @throws IOException for other issues.
*/
ApplicationId submitJob(RunJobParameters parameters)
ApplicationId submitJob(ParametersHolder parameters)
throws IOException, YarnException;
}

View File

@ -40,6 +40,10 @@ More details, please refer to
```$xslt
usage: job run
-framework <arg> Framework to use.
Valid values are: tensorflow, pytorch.
The default framework is Tensorflow.
-checkpoint_path <arg> Training output directory of the job, could
be local or other FS directory. This
typically includes checkpoint files and
@ -130,6 +134,7 @@ For submarine internal configuration, please create a `submarine.xml` which shou
#### Commandline
```
yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \
--framework tensorflow \
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
--env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \
--docker_image <your-docker-image> \
@ -163,6 +168,7 @@ See below screenshot:
```
yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
--name tf-job-001 --docker_image <your-docker-image> \
--framework tensorflow \
--input_path hdfs://default/dataset/cifar-10-data \
--checkpoint_path hdfs://default/tmp/cifar-10-jobdir \
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
@ -208,6 +214,7 @@ After that, you can run ```tensorboard --logdir=<checkpoint-path>``` to view Ten
yarn app -destroy tensorboard-service; \
yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \
job run --name tensorboard-service --verbose --docker_image <your-docker-image> \
--framework tensorflow \
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
--env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \
--num_workers 0 --tensorboard

View File

@ -1,226 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TestRunJobCliParsing {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Test
public void testPrintHelp() {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
JobMonitor mockJobMonitor = mock(JobMonitor.class);
RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
mockJobMonitor);
runJobCli.printUsages();
}
static MockClientContext getMockClientContext()
throws IOException, YarnException {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
ApplicationId.newInstance(1234L, 1));
JobMonitor mockJobMonitor = mock(JobMonitor.class);
SubmarineStorage storage = mock(SubmarineStorage.class);
RuntimeFactory rtFactory = mock(RuntimeFactory.class);
when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
when(rtFactory.getSubmarineStorage()).thenReturn(storage);
mockClientContext.setRuntimeFactory(rtFactory);
return mockClientContext;
}
@Test
public void testBasicRunJobForDistributedTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
"--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
"--principal", "user/_HOST@domain.com", "--distribute_keytab",
"--verbose" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(jobRunParameters.getNumPS(), 2);
assertEquals(jobRunParameters.getPSLaunchCmd(), "python run-ps.py");
assertEquals(Resources.createResource(4096, 4),
jobRunParameters.getPsResource());
assertEquals(jobRunParameters.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(2048, 2),
jobRunParameters.getWorkerResource());
assertEquals(jobRunParameters.getDockerImageName(),
"tf-docker:1.1.0");
assertEquals(jobRunParameters.getKeytab(),
"/keytab/path");
assertEquals(jobRunParameters.getPrincipal(),
"user/_HOST@domain.com");
Assert.assertTrue(jobRunParameters.isDistributeKeytab());
Assert.assertTrue(SubmarineLogs.isVerbose());
}
@Test
public void testBasicRunJobForSingleNodeTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(jobRunParameters.getNumWorkers(), 1);
assertEquals(jobRunParameters.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(4096, 2),
jobRunParameters.getWorkerResource());
Assert.assertTrue(SubmarineLogs.isVerbose());
Assert.assertTrue(jobRunParameters.isWaitJobFinish());
}
@Test
public void testNoInputPathOptionSpecified() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\" is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
/**
* when only run tensorboard, input_path is not needed
* */
@Test
public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
boolean success = true;
try {
runJobCli.run(
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
success = false;
}
Assert.assertTrue(success);
}
@Test
public void testJobWithoutName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage =
"--" + CliConstants.NAME + " is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testLaunchCommandPatternReplace() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py --input=%input_path% " +
"--model_dir=%checkpoint_path% " +
"--export_dir=%saved_model_path%/savedmodel",
"--worker_resources", "memory=2048,vcores=2", "--ps_resources",
"memory=4096,vcores=4", "--tensorboard", "true", "--ps_launch_cmd",
"python run-ps.py --input=%input_path% " +
"--model_dir=%checkpoint_path%/model",
"--verbose" });
assertEquals(
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+ "--export_dir=hdfs://output/savedmodel",
runJobCli.getRunJobParameters().getWorkerLaunchCmd());
assertEquals(
"python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model",
runJobCli.getRunJobParameters().getPSLaunchCmd());
}
}

View File

@ -17,7 +17,7 @@
package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.yaml.snakeyaml.Yaml;
import org.yaml.snakeyaml.constructor.Constructor;
@ -33,13 +33,13 @@ public final class YamlConfigTestUtils {
private YamlConfigTestUtils() {}
static void deleteFile(File file) {
public static void deleteFile(File file) {
if (file != null) {
file.delete();
}
}
static YamlConfigFile readYamlConfigFile(String filename) {
public static YamlConfigFile readYamlConfigFile(String filename) {
Constructor constructor = new Constructor(YamlConfigFile.class);
constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
Yaml yaml = new Yaml(constructor);
@ -49,7 +49,8 @@ public final class YamlConfigTestUtils {
return yaml.loadAs(inputStream, YamlConfigFile.class);
}
static File createTempFileWithContents(String filename) throws IOException {
public static File createTempFileWithContents(String filename)
throws IOException {
InputStream inputStream = YamlConfigTestUtils.class
.getClassLoader()
.getResourceAsStream(filename);
@ -58,7 +59,7 @@ public final class YamlConfigTestUtils {
return targetFile;
}
static File createEmptyTempFile() throws IOException {
public static File createEmptyTempFile() throws IOException {
return File.createTempFile("test", ".yaml");
}

View File

@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import org.apache.commons.cli.MissingArgumentException;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* This class contains some test methods to test common functionality
* (including TF / PyTorch) of the run job Submarine command.
*/
public class TestRunJobCliParsingCommon {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
public static MockClientContext getMockClientContext()
throws IOException, YarnException {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
when(mockJobSubmitter.submitJob(any(ParametersHolder.class)))
.thenReturn(ApplicationId.newInstance(1235L, 1));
JobMonitor mockJobMonitor = mock(JobMonitor.class);
SubmarineStorage storage = mock(SubmarineStorage.class);
RuntimeFactory rtFactory = mock(RuntimeFactory.class);
when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
when(rtFactory.getSubmarineStorage()).thenReturn(storage);
mockClientContext.setRuntimeFactory(rtFactory);
return mockClientContext;
}
@Test
public void testAbsentFrameworkFallsBackToTensorFlow() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
assertTrue("Default Framework should be TensorFlow!",
runJobParameters instanceof TensorFlowRunJobParameters);
}
@Test
public void testEmptyFrameworkOption() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(MissingArgumentException.class);
expectedException.expectMessage("Missing argument for option: framework");
runJobCli.run(
new String[]{"--framework", "--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
}
@Test
public void testInvalidFrameworkOption() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("Failed to parse Framework type");
runJobCli.run(
new String[]{"--framework", "bla", "--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
}
}

View File

@ -0,0 +1,252 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* This class contains some test methods to test common YAML parsing
* functionality (including TF / PyTorch) of the run job Submarine command.
*/
public class TestRunJobCliParsingCommonYaml {
private static final String DIR_NAME = "runjob-common-yaml";
private static final String TF_DIR = "runjob-pytorch-yaml";
private File yamlConfig;
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@After
public void after() {
YamlConfigTestUtils.deleteFile(yamlConfig);
}
@BeforeClass
public static void configureResourceTypes() {
List<ResourceTypeInfo> resTypes = new ArrayList<>(
ResourceUtils.getResourcesTypeInfo());
resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
ResourceUtils.reinitializeResources(resTypes);
}
@Rule
public ExpectedException exception = ExpectedException.none();
@Test
public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
TF_DIR + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"-f", yamlConfig.getAbsolutePath() };
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
TF_DIR + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--quicklink", "AAA=http://master-0:8321",
"--quicklink", "BBB=http://worker-0:1234",
"-f", yamlConfig.getAbsolutePath()};
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testFalseValuesForBooleanFields() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/test-false-values.yaml");
runJobCli.run(
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertFalse(jobRunParameters.isDistributeKeytab());
assertFalse(jobRunParameters.isWaitJobFinish());
assertFalse(tensorFlowParams.isTensorboardEnabled());
}
@Test
public void testWrongIndentation() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-indentation.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testWrongFilename() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", "not-existing", "--verbose"});
}
@Test
public void testEmptyFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testNotExistingFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("file does not exist");
runJobCli.run(
new String[]{"-f", "blabla", "--verbose"});
}
@Test
public void testWrongPropertyName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-property-name.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingConfigsSection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/missing-configs.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("config section should be defined, " +
"but it cannot be found");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingSectionsShouldParsed() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/some-sections-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testAbsentFramework() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/missing-framework.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testEmptyFramework() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/empty-framework.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testInvalidFramework() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/invalid-framework.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("framework should is defined, " +
"but it has an invalid value");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
}

View File

@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import com.google.common.collect.Lists;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
/**
* This class contains some test methods to test common CLI parsing
* functionality (including TF / PyTorch) of the run job Submarine command.
*/
@RunWith(Parameterized.class)
public class TestRunJobCliParsingParameterized {
private final Framework framework;
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Parameterized.Parameters
public static Collection<Object[]> data() {
Collection<Object[]> params = new ArrayList<>();
params.add(new Object[]{Framework.TENSORFLOW });
params.add(new Object[]{Framework.PYTORCH });
return params;
}
public TestRunJobCliParsingParameterized(Framework framework) {
this.framework = framework;
}
private String getFrameworkName() {
return framework.getValue();
}
@Test
public void testPrintHelp() {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
JobMonitor mockJobMonitor = mock(JobMonitor.class);
RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
mockJobMonitor);
runJobCli.printUsages();
}
@Test
public void testNoInputPathOptionSpecified() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\"" +
" is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--framework", getFrameworkName(),
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--verbose",
"--wait_job_finish"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testJobWithoutName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage =
"--" + CliConstants.NAME + " is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--framework", getFrameworkName(),
"--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--verbose"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testLaunchCommandPatternReplace() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
List<String> parameters = Lists.newArrayList("--framework",
getFrameworkName(),
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd", "python run-job.py --input=%input_path% " +
"--model_dir=%checkpoint_path% " +
"--export_dir=%saved_model_path%/savedmodel",
"--worker_resources", "memory=2048,vcores=2");
if (framework == Framework.TENSORFLOW) {
parameters.addAll(Lists.newArrayList(
"--ps_resources", "memory=4096,vcores=4",
"--ps_launch_cmd", "python run-ps.py --input=%input_path% " +
"--model_dir=%checkpoint_path%/model",
"--verbose"));
}
runJobCli.run(parameters.toArray(new String[0]));
RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);
if (framework == Framework.TENSORFLOW) {
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) runJobParameters;
assertEquals(
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+ "--export_dir=hdfs://output/savedmodel",
tensorFlowParams.getWorkerLaunchCmd());
assertEquals(
"python run-ps.py --input=hdfs://input " +
"--model_dir=hdfs://output/model",
tensorFlowParams.getPSLaunchCmd());
} else if (framework == Framework.PYTORCH) {
PyTorchRunJobParameters pyTorchParameters =
(PyTorchRunJobParameters) runJobParameters;
assertEquals(
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+ "--export_dir=hdfs://output/savedmodel",
pyTorchParameters.getWorkerLaunchCmd());
}
}
private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
if (framework == Framework.TENSORFLOW) {
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
runJobParameters instanceof TensorFlowRunJobParameters);
} else if (framework == Framework.PYTORCH) {
assertTrue(RunJobParameters.class + " must be an instance of " +
PyTorchRunJobParameters.class,
runJobParameters instanceof PyTorchRunJobParameters);
}
return runJobParameters;
}
}

View File

@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* Test class that verifies the correctness of PyTorch
* CLI configuration parsing.
*/
public class TestRunJobCliParsingPyTorch {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testBasicRunJobForSingleNodeTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--verbose",
"--wait_job_finish" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class +
" must be an instance of " +
PyTorchRunJobParameters.class,
jobRunParameters instanceof PyTorchRunJobParameters);
PyTorchRunJobParameters pyTorchParams =
(PyTorchRunJobParameters) jobRunParameters;
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(pyTorchParams.getNumWorkers(), 1);
assertEquals(pyTorchParams.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(4096, 2),
pyTorchParams.getWorkerResource());
assertTrue(SubmarineLogs.isVerbose());
assertTrue(jobRunParameters.isWaitJobFinish());
}
@Test
public void testNumPSCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path","hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--num_ps", "2" });
}
@Test
public void testPSResourcesCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=2048M,vcores=2" });
}
@Test
public void testPSDockerImageCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_docker_image", "psDockerImage" });
}
@Test
public void testPSLaunchCommandCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_launch_cmd", "psLaunchCommand" });
}
@Test
public void testTensorboardCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--tensorboard" });
}
@Test
public void testTensorboardResourcesCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--tensorboard_resources", "memory=2048M,vcores=2" });
}
@Test
public void testTensorboardDockerImageCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--tensorboard_docker_image", "TBDockerImage" });
}
}

View File

@ -0,0 +1,225 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
/**
* Test class that verifies the correctness of PyTorch
* YAML configuration parsing.
*/
public class TestRunJobCliParsingPyTorchYaml {
private static final String OVERRIDDEN_PREFIX = "overridden_";
private static final String DIR_NAME = "runjob-pytorch-yaml";
private File yamlConfig;
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@After
public void after() {
YamlConfigTestUtils.deleteFile(yamlConfig);
}
@BeforeClass
public static void configureResourceTypes() {
List<ResourceTypeInfo> resTypes = new ArrayList<>(
ResourceUtils.getResourcesTypeInfo());
resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
ResourceUtils.reinitializeResources(resTypes);
}
@Rule
public ExpectedException exception = ExpectedException.none();
private void verifyBasicConfigValues(RunJobParameters jobRunParameters) {
verifyBasicConfigValues(jobRunParameters,
ImmutableList.of("env1=env1Value", "env2=env2Value"));
}
private void verifyBasicConfigValues(RunJobParameters jobRunParameters,
List<String> expectedEnvs) {
assertEquals("testInputPath", jobRunParameters.getInputPath());
assertEquals("testCheckpointPath", jobRunParameters.getCheckpointPath());
assertEquals("testDockerImage", jobRunParameters.getDockerImageName());
assertNotNull(jobRunParameters.getLocalizations());
assertEquals(2, jobRunParameters.getLocalizations().size());
assertNotNull(jobRunParameters.getQuicklinks());
assertEquals(2, jobRunParameters.getQuicklinks().size());
assertTrue(SubmarineLogs.isVerbose());
assertTrue(jobRunParameters.isWaitJobFinish());
for (String env : expectedEnvs) {
assertTrue(String.format(
"%s should be in env list of jobRunParameters!", env),
jobRunParameters.getEnvars().contains(env));
}
}
private void verifyWorkerValues(RunJobParameters jobRunParameters,
String prefix) {
assertTrue(RunJobParameters.class + " must be an instance of " +
PyTorchRunJobParameters.class,
jobRunParameters instanceof PyTorchRunJobParameters);
PyTorchRunJobParameters tensorFlowParams =
(PyTorchRunJobParameters) jobRunParameters;
assertEquals(3, tensorFlowParams.getNumWorkers());
assertEquals(prefix + "testLaunchCmdWorker",
tensorFlowParams.getWorkerLaunchCmd());
assertEquals(prefix + "testDockerImageWorker",
tensorFlowParams.getWorkerDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "2").build()),
tensorFlowParams.getWorkerResource());
}
private void verifySecurityValues(RunJobParameters jobRunParameters) {
assertEquals("keytabPath", jobRunParameters.getKeytab());
assertEquals("testPrincipal", jobRunParameters.getPrincipal());
assertTrue(jobRunParameters.isDistributeKeytab());
}
@Test
public void testValidYamlParsing() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config.yaml");
runJobCli.run(
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyWorkerValues(jobRunParameters, "");
verifySecurityValues(jobRunParameters);
}
@Test
public void testRoleOverrides() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config-with-overrides.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyWorkerValues(jobRunParameters, OVERRIDDEN_PREFIX);
verifySecurityValues(jobRunParameters);
}
@Test
public void testMissingPrincipalUnderSecuritySection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/security-principal-is-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyWorkerValues(jobRunParameters, "");
//Verify security values
assertEquals("keytabPath", jobRunParameters.getKeytab());
assertNull("Principal should be null!", jobRunParameters.getPrincipal());
assertTrue(jobRunParameters.isDistributeKeytab());
}
@Test
public void testMissingEnvs() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/envs-are-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters, ImmutableList.of());
verifyWorkerValues(jobRunParameters, "");
verifySecurityValues(jobRunParameters);
}
@Test
public void testInvalidConfigPsSectionIsDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("PS section should not be defined " +
"when PyTorch is the selected framework");
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/invalid-config-ps-section.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testInvalidConfigTensorboardSectionIsDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("TensorBoard section should not be defined " +
"when PyTorch is the selected framework");
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/invalid-config-tensorboard-section.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
}

View File

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* Test class that verifies the correctness of TensorFlow
* CLI configuration parsing.
*/
public class TestRunJobCliParsingTensorFlow {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testNoInputPathOptionSpecified() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
"\" is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testBasicRunJobForDistributedTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
"--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
"--principal", "user/_HOST@domain.com", "--distribute_keytab",
"--verbose" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class +
" must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(tensorFlowParams.getNumPS(), 2);
assertEquals(tensorFlowParams.getPSLaunchCmd(), "python run-ps.py");
assertEquals(Resources.createResource(4096, 4),
tensorFlowParams.getPsResource());
assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(2048, 2),
tensorFlowParams.getWorkerResource());
assertEquals(jobRunParameters.getDockerImageName(),
"tf-docker:1.1.0");
assertEquals(jobRunParameters.getKeytab(),
"/keytab/path");
assertEquals(jobRunParameters.getPrincipal(),
"user/_HOST@domain.com");
assertTrue(jobRunParameters.isDistributeKeytab());
assertTrue(SubmarineLogs.isVerbose());
}
@Test
public void testBasicRunJobForSingleNodeTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class +
" must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(tensorFlowParams.getNumWorkers(), 1);
assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(4096, 2),
tensorFlowParams.getWorkerResource());
assertTrue(SubmarineLogs.isVerbose());
assertTrue(jobRunParameters.isWaitJobFinish());
}
/**
* when only run tensorboard, input_path is not needed
* */
@Test
public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
boolean success = true;
try {
runJobCli.run(
new String[]{"--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
success = false;
}
assertTrue(success);
}
}

View File

@ -15,16 +15,17 @@
*/
package org.apache.hadoop.yarn.submarine.client.cli;
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
@ -39,19 +40,18 @@ import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.client.cli.TestRunJobCliParsing.getMockClientContext;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
/**
* Test class that verifies the correctness of YAML configuration parsing.
* Test class that verifies the correctness of TF YAML configuration parsing.
*/
public class TestRunJobCliParsingYaml {
public class TestRunJobCliParsingTensorFlowYaml {
private static final String OVERRIDDEN_PREFIX = "overridden_";
private static final String DIR_NAME = "runjobcliparsing";
private static final String DIR_NAME = "runjob-tensorflow-yaml";
private File yamlConfig;
@Before
@ -104,27 +104,39 @@ public class TestRunJobCliParsingYaml {
private void verifyPsValues(RunJobParameters jobRunParameters,
String prefix) {
assertEquals(4, jobRunParameters.getNumPS());
assertEquals(prefix + "testLaunchCmdPs", jobRunParameters.getPSLaunchCmd());
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(4, tensorFlowParams.getNumPS());
assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
assertEquals(prefix + "testDockerImagePs",
jobRunParameters.getPsDockerImage());
tensorFlowParams.getPsDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(20500L, 34,
ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "4").build()),
jobRunParameters.getPsResource());
tensorFlowParams.getPsResource());
}
private void verifyWorkerValues(RunJobParameters jobRunParameters,
String prefix) {
assertEquals(3, jobRunParameters.getNumWorkers());
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(3, tensorFlowParams.getNumWorkers());
assertEquals(prefix + "testLaunchCmdWorker",
jobRunParameters.getWorkerLaunchCmd());
tensorFlowParams.getWorkerLaunchCmd());
assertEquals(prefix + "testDockerImageWorker",
jobRunParameters.getWorkerDockerImage());
tensorFlowParams.getWorkerDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "2").build()),
jobRunParameters.getWorkerResource());
tensorFlowParams.getWorkerResource());
}
private void verifySecurityValues(RunJobParameters jobRunParameters) {
@ -134,13 +146,19 @@ public class TestRunJobCliParsingYaml {
}
private void verifyTensorboardValues(RunJobParameters jobRunParameters) {
assertTrue(jobRunParameters.isTensorboardEnabled());
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertTrue(tensorFlowParams.isTensorboardEnabled());
assertEquals("tensorboardDockerImage",
jobRunParameters.getTensorboardDockerImage());
tensorFlowParams.getTensorboardDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "3").build()),
jobRunParameters.getTensorboardResource());
tensorFlowParams.getTensorboardResource());
}
@Test
@ -161,44 +179,6 @@ public class TestRunJobCliParsingYaml {
verifyTensorboardValues(jobRunParameters);
}
@Test
public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"-f", yamlConfig.getAbsolutePath() };
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--quicklink", "AAA=http://master-0:8321",
"--quicklink", "BBB=http://worker-0:1234",
"-f", yamlConfig.getAbsolutePath()};
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testRoleOverrides() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@ -217,104 +197,6 @@ public class TestRunJobCliParsingYaml {
verifyTensorboardValues(jobRunParameters);
}
@Test
public void testFalseValuesForBooleanFields() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/test-false-values.yaml");
runJobCli.run(
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertFalse(jobRunParameters.isDistributeKeytab());
assertFalse(jobRunParameters.isWaitJobFinish());
assertFalse(jobRunParameters.isTensorboardEnabled());
}
@Test
public void testWrongIndentation() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-indentation.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testWrongFilename() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", "not-existing", "--verbose"});
}
@Test
public void testEmptyFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testNotExistingFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("file does not exist");
runJobCli.run(
new String[]{"-f", "blabla", "--verbose"});
}
@Test
public void testWrongPropertyName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-property-name.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingConfigsSection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/missing-configs.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("config section should be defined, " +
"but it cannot be found");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingSectionsShouldParsed() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/some-sections-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingPrincipalUnderSecuritySection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@ -346,18 +228,22 @@ public class TestRunJobCliParsingYaml {
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyPsValues(jobRunParameters, "");
verifyWorkerValues(jobRunParameters, "");
verifySecurityValues(jobRunParameters);
assertTrue(jobRunParameters.isTensorboardEnabled());
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertTrue(tensorFlowParams.isTensorboardEnabled());
assertNull("tensorboardDockerImage should be null!",
jobRunParameters.getTensorboardDockerImage());
tensorFlowParams.getTensorboardDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "3").build()),
jobRunParameters.getTensorboardResource());
tensorFlowParams.getTensorboardResource());
}
@Test

View File

@ -15,7 +15,7 @@
*/
package org.apache.hadoop.yarn.submarine.client.cli;
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
@ -42,14 +42,9 @@ import static org.junit.Assert.assertTrue;
* Please note that this class just tests YAML parsing,
* but only in an isolated fashion.
*/
public class TestRunJobCliParsingYamlStandalone {
public class TestRunJobCliParsingTensorFlowYamlStandalone {
private static final String OVERRIDDEN_PREFIX = "overridden_";
private static final String DIR_NAME = "runjobcliparsing";
@Before
public void before() {
SubmarineLogs.verboseOff();
}
private static final String DIR_NAME = "runjob-tensorflow-yaml";
private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) {
assertNotNull("Spec file should not be null!", yamlConfigFile);
@ -169,6 +164,11 @@ public class TestRunJobCliParsingYamlStandalone {
assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources());
}
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Test
public void testLaunchCommandYaml() {
YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME +
@ -201,5 +201,4 @@ public class TestRunJobCliParsingYamlStandalone {
assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker");
assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps");
}
}

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework:
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
resources: memory=20500M,vcores=34,gpu=4
replicas: 4
launch_cmd: testLaunchCmdPs
docker_image: testDockerImagePs
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
resources: memory=21000M,vcores=37,gpu=3
docker_image: tensorboardDockerImage

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: bla
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
resources: memory=20500M,vcores=34,gpu=4
replicas: 4
launch_cmd: testLaunchCmdPs
docker_image: testDockerImagePs
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
resources: memory=21000M,vcores=37,gpu=3
docker_image: tensorboardDockerImage

View File

@ -17,6 +17,7 @@
spec:
name: testJobName
job_type: testJobType
framework: tensorflow
configs:
input_path: testInputPath

View File

@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
docker_image: testPSDockerImage
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
docker_image: tensorboardDockerImage

View File

@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
distribute_keytab: true

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: overridden_testLaunchCmdWorker
docker_image: overridden_testDockerImageWorker
envs:
env1: 'overridden_env1Worker'
env2: 'overridden_env2Worker'
localizations:
- hdfs://remote-file1:/overridden_local-filename1Worker:rw
- nfs://remote-file2:/overridden_local-filename2Worker:rw
mounts:
- /etc/passwd:/overridden_Worker
- /etc/hosts:/overridden_Worker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: tensorflow
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
resources: memory=20500M,vcores=34,gpu=4
replicas: 4
launch_cmd: testLaunchCmdPs
docker_image: testDockerImagePs
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
resources: memory=21000M,vcores=37,gpu=3
docker_image: tensorboardDockerImage

View File

@ -22,7 +22,9 @@ import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import java.io.File;
@ -45,14 +47,24 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
}
@Override
public ApplicationId submitJob(RunJobParameters parameters)
throws IOException, YarnException {
public ApplicationId submitJob(ParametersHolder parameters)
throws IOException {
if (parameters.getFramework() == Framework.PYTORCH) {
// we need to throw an exception, as ParametersHolder's parameters field
// could not be casted to TensorFlowRunJobParameters.
throw new UnsupportedOperationException(
"Support \"-framework\" option for PyTorch in Tony is coming. " +
"Please check the documentation about how to submit a " +
"PyTorch job with TonY runtime.");
}
LOG.info("Starting Tony runtime..");
File tonyFinalConfPath = File.createTempFile("temp",
Constants.TONY_FINAL_XML);
// Write user's overridden conf to an xml to be localized.
Configuration tonyConf = TonyUtils.tonyConfFromClientContext(parameters);
Configuration tonyConf = TonyUtils.tonyConfFromClientContext(
(TensorFlowRunJobParameters) parameters.getParameters());
try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
tonyConf.writeXml(os);
} catch (IOException e) {
@ -68,7 +80,7 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
LOG.error("Failed to init TonyClient: ", e);
}
Thread clientThread = new Thread(tonyClient::start);
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
java.lang.Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
tonyClient.forceKillApplication();
} catch (YarnException | IOException e) {

View File

@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import java.util.ArrayList;
import java.util.Arrays;
@ -35,7 +35,7 @@ public final class TonyUtils {
private static final Log LOG = LogFactory.getLog(TonyUtils.class);
public static Configuration tonyConfFromClientContext(
RunJobParameters parameters) {
TensorFlowRunJobParameters parameters) {
Configuration tonyConf = new Configuration();
tonyConf.setInt(
TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME),

View File

@ -147,6 +147,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--num_workers 2 \
--worker_resources memory=3G,vcores=2 \
--num_ps 2 \
@ -183,6 +184,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
--input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
--worker_resources memory=3G,vcores=2 \
@ -245,6 +247,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--num_workers 2 \
--worker_resources memory=3G,vcores=2 \
--num_ps 2 \
@ -281,6 +284,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
--input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
--worker_resources memory=3G,vcores=2 \

View File

@ -16,8 +16,10 @@ import com.linkedin.tony.TonyConfigurationKeys;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
@ -31,6 +33,7 @@ import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -59,7 +62,8 @@ public class TestTonyUtils {
throws IOException, YarnException {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
when(mockJobSubmitter.submitJob(
any(ParametersHolder.class))).thenReturn(
ApplicationId.newInstance(1234L, 1));
JobMonitor mockJobMonitor = mock(JobMonitor.class);
SubmarineStorage storage = mock(SubmarineStorage.class);
@ -82,20 +86,28 @@ public class TestTonyUtils {
public void testTonyConfFromClientContext() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
runJobCli.run(
new String[] {"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
new String[] {"--framework", "tensorflow", "--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd",
"python run-ps.py"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
Configuration tonyConf = TonyUtils
.tonyConfFromClientContext(jobRunParameters);
.tonyConfFromClientContext(tensorFlowParams);
Assert.assertEquals(jobRunParameters.getDockerImageName(),
tonyConf.get(TonyConfigurationKeys.getContainerDockerKey()));
Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys
.getInstancesKey("worker")));
Assert.assertEquals(jobRunParameters.getWorkerLaunchCmd(),
Assert.assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
tonyConf.get(TonyConfigurationKeys
.getExecuteCommandKey("worker")));
Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys
@ -107,7 +119,7 @@ public class TestTonyUtils {
Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
.getResourceKey(Constants.PS_JOB_NAME,
Constants.VCORES)));
Assert.assertEquals(jobRunParameters.getPSLaunchCmd(),
Assert.assertEquals(tensorFlowParams.getPSLaunchCmd(),
tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
}
}

View File

@ -19,8 +19,10 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
@ -28,7 +30,11 @@ import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchComma
import java.io.IOException;
import java.util.Objects;
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
/**
* Abstract base class for Component classes.
@ -40,7 +46,7 @@ import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.T
public abstract class AbstractComponent {
private final FileSystemOperations fsOperations;
protected final RunJobParameters parameters;
protected final TaskType taskType;
protected final Role role;
private final RemoteDirectoryManager remoteDirectoryManager;
protected final Configuration yarnConfig;
private final LaunchCommandFactory launchCommandFactory;
@ -52,19 +58,55 @@ public abstract class AbstractComponent {
public AbstractComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters, TaskType taskType,
RunJobParameters parameters, Role role,
Configuration yarnConfig,
LaunchCommandFactory launchCommandFactory) {
this.fsOperations = fsOperations;
this.remoteDirectoryManager = remoteDirectoryManager;
this.parameters = parameters;
this.taskType = taskType;
this.role = role;
this.launchCommandFactory = launchCommandFactory;
this.yarnConfig = yarnConfig;
}
protected abstract Component createComponent() throws IOException;
protected Component createComponentInternal() throws IOException {
Objects.requireNonNull(this.parameters.getWorkerResource(),
"Worker resource must not be null!");
if (parameters.getNumWorkers() < 1) {
throw new IllegalArgumentException(
"Number of workers should be at least 1!");
}
Component component = new Component();
component.setName(role.getComponentName());
if (role.equals(TensorFlowRole.PRIMARY_WORKER) ||
role.equals(PyTorchRole.PRIMARY_WORKER)) {
component.setNumberOfContainers(1L);
component.getConfiguration().setProperty(
CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
} else {
component.setNumberOfContainers(
(long) parameters.getNumWorkers() - 1);
}
if (parameters.getWorkerDockerImage() != null) {
component.setArtifact(
getDockerArtifact(parameters.getWorkerDockerImage()));
}
component.setResource(convertYarnResourceToServiceResource(
parameters.getWorkerResource()));
component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
addCommonEnvironments(component, role);
generateLaunchCommand(component);
return component;
}
/**
* Generates a command launch script on local disk,
* returns path to the script.
@ -72,7 +114,7 @@ public abstract class AbstractComponent {
protected void generateLaunchCommand(Component component)
throws IOException {
AbstractLaunchCommand launchCommand =
launchCommandFactory.createLaunchCommand(taskType, component);
launchCommandFactory.createLaunchCommand(role, component);
this.localScriptFile = launchCommand.generateLaunchScript();
String remoteLaunchCommand = uploadLaunchCommand(component);
@ -86,7 +128,7 @@ public abstract class AbstractComponent {
Path stagingDir =
remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
String destScriptFileName = getScriptFileName(taskType);
String destScriptFileName = getScriptFileName(role);
fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
localScriptFile, destScriptFileName, component);

View File

@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
/**
* Abstract base class that supports creating service specs for Native Service.
*/
public abstract class AbstractServiceSpec implements ServiceSpec {
private static final Logger LOG =
LoggerFactory.getLogger(AbstractServiceSpec.class);
protected final RunJobParameters parameters;
protected final FileSystemOperations fsOperations;
private final Localizer localizer;
protected final RemoteDirectoryManager remoteDirectoryManager;
protected final Configuration yarnConfig;
protected final LaunchCommandFactory launchCommandFactory;
private final WorkerComponentFactory workerFactory;
public AbstractServiceSpec(RunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations,
LaunchCommandFactory launchCommandFactory,
Localizer localizer) {
this.parameters = parameters;
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
this.yarnConfig = clientContext.getYarnConfig();
this.fsOperations = fsOperations;
this.localizer = localizer;
this.launchCommandFactory = launchCommandFactory;
this.workerFactory = new WorkerComponentFactory(fsOperations,
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
}
protected ServiceWrapper createServiceSpecWrapper() throws IOException {
Service serviceSpec = new Service();
serviceSpec.setName(parameters.getName());
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
.create(fsOperations, remoteDirectoryManager, parameters);
if (kerberosPrincipal != null) {
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
}
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
localizer.handleLocalizations(serviceSpec);
return new ServiceWrapper(serviceSpec);
}
// Handle worker and primary_worker.
protected void addWorkerComponents(ServiceWrapper serviceWrapper,
Framework framework)
throws IOException {
final Role primaryWorkerRole;
final Role workerRole;
if (framework == Framework.TENSORFLOW) {
primaryWorkerRole = TensorFlowRole.PRIMARY_WORKER;
workerRole = TensorFlowRole.WORKER;
} else {
primaryWorkerRole = PyTorchRole.PRIMARY_WORKER;
workerRole = PyTorchRole.WORKER;
}
addWorkerComponent(serviceWrapper, primaryWorkerRole, framework);
if (parameters.getNumWorkers() > 1) {
addWorkerComponent(serviceWrapper, workerRole, framework);
}
}
private void addWorkerComponent(ServiceWrapper serviceWrapper,
Role role, Framework framework) throws IOException {
AbstractComponent component = workerFactory.create(framework, role);
serviceWrapper.addComponent(component);
}
protected void handleQuicklinks(Service serviceSpec)
throws IOException {
List<Quicklink> quicklinks = parameters.getQuicklinks();
if (quicklinks != null && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol()
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
protected static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (quicklinks == null) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
}

View File

@ -37,6 +37,7 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
@ -195,6 +196,15 @@ public class FileSystemOperations {
fs.setPermission(destPath, new FsPermission(permission));
}
public static boolean needHdfs(List<String> stringsToCheck) {
for (String content : stringsToCheck) {
if (content != null && content.contains("hdfs://")) {
return true;
}
}
return false;
}
public static boolean needHdfs(String content) {
return content != null && content.contains("hdfs://");
}

View File

@ -16,9 +16,10 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.curator.shaded.com.google.common.collect.ImmutableList;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
@ -28,6 +29,8 @@ import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.Objects;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
@ -128,10 +131,22 @@ public class HadoopEnvironmentSetup {
}
private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
return needHdfs(parameters.getInputPath()) ||
needHdfs(parameters.getPSLaunchCmd()) ||
needHdfs(parameters.getWorkerLaunchCmd()) ||
hadoopEnv;
List<String> launchCommands = parameters.getLaunchCommands();
if (launchCommands != null) {
launchCommands.removeIf(Objects::isNull);
}
ImmutableList.Builder<String> listBuilder = ImmutableList.builder();
if (launchCommands != null && !launchCommands.isEmpty()) {
listBuilder.addAll(launchCommands);
}
if (parameters.getInputPath() != null) {
listBuilder.add(parameters.getInputPath());
}
List<String> stringsToCheck = listBuilder.build();
return needHdfs(stringsToCheck) || hadoopEnv;
}
private void appendHdfsHome(PrintWriter fw, String hdfsHome) {

View File

@ -38,7 +38,7 @@ public final class ServiceSpecFileGenerator {
"instantiated!");
}
static String generateJson(Service service) throws IOException {
public static String generateJson(Service service) throws IOException {
File serviceSpecFile = File.createTempFile(service.getName(), ".json");
String buffer = jsonSerDeser.toJson(service);
Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),

View File

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component.PyTorchWorkerComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
/**
* Factory class that helps creating Native Service components.
*/
public class WorkerComponentFactory {
private final FileSystemOperations fsOperations;
private final RemoteDirectoryManager remoteDirectoryManager;
private final RunJobParameters parameters;
private final LaunchCommandFactory launchCommandFactory;
private final Configuration yarnConfig;
WorkerComponentFactory(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters,
LaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
this.fsOperations = fsOperations;
this.remoteDirectoryManager = remoteDirectoryManager;
this.parameters = parameters;
this.launchCommandFactory = launchCommandFactory;
this.yarnConfig = yarnConfig;
}
/**
* Creates either a TensorFlow or a PyTorch Native Service component.
*/
public AbstractComponent create(Framework framework, Role role) {
if (framework == Framework.TENSORFLOW) {
return new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
(TensorFlowRunJobParameters) parameters, role,
(TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
} else if (framework == Framework.PYTORCH) {
return new PyTorchWorkerComponent(fsOperations, remoteDirectoryManager,
(PyTorchRunJobParameters) parameters, role,
(PyTorchLaunchCommandFactory) launchCommandFactory, yarnConfig);
} else {
throw new UnsupportedOperationException("Only supported frameworks are: "
+ Framework.getValues());
}
}
}

View File

@ -20,10 +20,16 @@ import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
@ -32,6 +38,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
import static org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder.SUPPORTED_FRAMEWORKS_MESSAGE;
/**
* Submit a job to cluster.
@ -51,14 +58,45 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
* {@inheritDoc}
*/
@Override
public ApplicationId submitJob(RunJobParameters parameters)
public ApplicationId submitJob(ParametersHolder paramsHolder)
throws IOException, YarnException {
Framework framework = paramsHolder.getFramework();
RunJobParameters parameters =
(RunJobParameters) paramsHolder.getParameters();
if (framework == Framework.TENSORFLOW) {
return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
} else if (framework == Framework.PYTORCH) {
return submitPyTorchJob((PyTorchRunJobParameters) parameters);
} else {
throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
}
}
private ApplicationId submitTensorFlowJob(
TensorFlowRunJobParameters parameters) throws IOException, YarnException {
FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(clientContext, fsOperations);
Service serviceSpec = createTensorFlowServiceSpec(parameters,
fsOperations, hadoopEnvSetup);
return submitJobInternal(serviceSpec);
}
private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters)
throws IOException, YarnException {
FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(clientContext, fsOperations);
Service serviceSpec = createPyTorchServiceSpec(parameters,
fsOperations, hadoopEnvSetup);
return submitJobInternal(serviceSpec);
}
private ApplicationId submitJobInternal(Service serviceSpec)
throws IOException, YarnException {
String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
AppAdminClient appAdminClient =
@ -70,7 +108,7 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
"Fail to launch application with exit code:" + code);
}
String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
String appStatus = appAdminClient.getStatusString(serviceSpec.getName());
Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
// Retry multiple times if applicationId is null
@ -97,11 +135,12 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
return appid;
}
private Service createTensorFlowServiceSpec(RunJobParameters parameters,
private Service createTensorFlowServiceSpec(
TensorFlowRunJobParameters parameters,
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
throws IOException {
LaunchCommandFactory launchCommandFactory =
new LaunchCommandFactory(hadoopEnvSetup, parameters,
TensorFlowLaunchCommandFactory launchCommandFactory =
new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters,
clientContext.getYarnConfig());
Localizer localizer = new Localizer(fsOperations,
clientContext.getRemoteDirectoryManager(), parameters);
@ -113,6 +152,22 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
return serviceWrapper.getService();
}
private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
throws IOException {
PyTorchLaunchCommandFactory launchCommandFactory =
new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters,
clientContext.getYarnConfig());
Localizer localizer = new Localizer(fsOperations,
clientContext.getRemoteDirectoryManager(), parameters);
PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
parameters, this.clientContext, fsOperations, launchCommandFactory,
localizer);
serviceWrapper = pyTorchServiceSpec.create();
return serviceWrapper.getService();
}
@VisibleForTesting
public ServiceWrapper getServiceWrapper() {
return serviceWrapper;

View File

@ -17,11 +17,9 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import java.io.IOException;
import java.util.Objects;
/**
* Abstract base class for Launch command implementations for Services.
@ -32,10 +30,9 @@ public abstract class AbstractLaunchCommand {
private final LaunchScriptBuilder builder;
public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters)
throws IOException {
Objects.requireNonNull(taskType, "TaskType must not be null!");
this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
Component component, RunJobParameters parameters,
String launchCommandPrefix) throws IOException {
this.builder = new LaunchScriptBuilder(launchCommandPrefix, hadoopEnvSetup,
parameters, component);
}

View File

@ -16,52 +16,15 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import java.io.IOException;
import java.util.Objects;
/**
* Simple factory to create instances of {@link AbstractLaunchCommand}
* based on the {@link TaskType}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
* Interface for creating launch commands.
*/
public class LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final RunJobParameters parameters;
private final Configuration yarnConfig;
public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
RunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
Component component) throws IOException {
Objects.requireNonNull(taskType, "TaskType must not be null!");
if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
component, parameters, yarnConfig);
} else if (taskType == TaskType.PS) {
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
parameters, yarnConfig);
} else if (taskType == TaskType.TENSORBOARD) {
return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
parameters);
}
throw new IllegalStateException("Unknown task type: " + taskType);
}
public interface LaunchCommandFactory {
AbstractLaunchCommand createLaunchCommand(Role role, Component component)
throws IOException;
}

View File

@ -17,7 +17,7 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -47,10 +47,11 @@ public class LaunchScriptBuilder {
private final StringBuilder scriptBuffer;
private String launchCommand;
LaunchScriptBuilder(String namePrefix,
LaunchScriptBuilder(String launchScriptPrefix,
HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
Component component) throws IOException {
this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
this.file = File.createTempFile(launchScriptPrefix +
"-launch-script", ".sh");
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.component = component;

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import java.io.IOException;
import java.util.Objects;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command.PyTorchWorkerLaunchCommand;
/**
* Simple factory to create instances of {@link AbstractLaunchCommand}
* based on the {@link Role}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
*/
public class PyTorchLaunchCommandFactory implements LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final PyTorchRunJobParameters parameters;
private final Configuration yarnConfig;
public PyTorchLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
PyTorchRunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
public AbstractLaunchCommand createLaunchCommand(Role role,
Component component) throws IOException {
Objects.requireNonNull(role, "Role must not be null!");
if (role == PyTorchRole.WORKER ||
role == PyTorchRole.PRIMARY_WORKER) {
return new PyTorchWorkerLaunchCommand(hadoopEnvSetup, role,
component, parameters, yarnConfig);
} else {
throw new IllegalStateException("Unknown task type: " + role);
}
}
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import java.io.IOException;
import java.util.Objects;
/**
* Simple factory to create instances of {@link AbstractLaunchCommand}
* based on the {@link Role}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
*/
public class TensorFlowLaunchCommandFactory implements LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final TensorFlowRunJobParameters parameters;
private final Configuration yarnConfig;
public TensorFlowLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
TensorFlowRunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
@Override
public AbstractLaunchCommand createLaunchCommand(Role role,
Component component) throws IOException {
Objects.requireNonNull(role, "Role must not be null!");
if (role == TensorFlowRole.WORKER ||
role == TensorFlowRole.PRIMARY_WORKER) {
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, role,
component, parameters, yarnConfig);
} else if (role == TensorFlowRole.PS) {
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, role, component,
parameters, yarnConfig);
} else if (role == TensorFlowRole.TENSORBOARD) {
return new TensorBoardLaunchCommand(hadoopEnvSetup, role, component,
parameters);
}
throw new IllegalStateException("Unknown task type: " + role);
}
}

View File

@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
import java.io.IOException;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This class contains all the logic to create an instance
* of a {@link Service} object for PyTorch.
* Please note that currently, only single-node (non-distributed)
* support is implemented for PyTorch.
*/
public class PyTorchServiceSpec extends AbstractServiceSpec {
private static final Logger LOG =
LoggerFactory.getLogger(PyTorchServiceSpec.class);
//this field is needed in the future!
private final PyTorchRunJobParameters pyTorchParameters;
public PyTorchServiceSpec(PyTorchRunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations,
PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer) {
super(parameters, clientContext, fsOperations, launchCommandFactory,
localizer);
this.pyTorchParameters = parameters;
}
@Override
public ServiceWrapper create() throws IOException {
LOG.info("Creating PyTorch service spec");
ServiceWrapper serviceWrapper = createServiceSpecWrapper();
if (parameters.getNumWorkers() > 0) {
addWorkerComponents(serviceWrapper, Framework.PYTORCH);
}
// After all components added, handle quicklinks
handleQuicklinks(serviceWrapper.getService());
return serviceWrapper;
}
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
import java.io.IOException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Launch command implementation for PyTorch components.
*/
public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
private static final Logger LOG =
LoggerFactory.getLogger(PyTorchWorkerLaunchCommand.class);
private final Configuration yarnConfig;
private final boolean distributed;
private final int numberOfWorkers;
private final String name;
private final Role role;
private final String launchCommand;
public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
Role role, Component component,
PyTorchRunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, component, parameters, role.getName());
this.role = role;
this.name = parameters.getName();
this.distributed = parameters.isDistributed();
this.numberOfWorkers = parameters.getNumWorkers();
this.yarnConfig = yarnConfig;
logReceivedParameters();
this.launchCommand = parameters.getWorkerLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) {
throw new IllegalArgumentException("LaunchCommand must not be null " +
"or empty!");
}
}
private void logReceivedParameters() {
if (this.numberOfWorkers <= 0) {
LOG.warn("Received number of workers: {}", this.numberOfWorkers);
}
}
@Override
public String generateLaunchScript() throws IOException {
LaunchScriptBuilder builder = getBuilder();
return builder
.withLaunchCommand(createLaunchCommand())
.build();
}
@Override
public String createLaunchCommand() {
if (SubmarineLogs.isVerbose()) {
LOG.info("PyTorch Worker command =[" + launchCommand + "]");
}
return launchCommand + '\n';
}
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate PyTorch launch commands.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import java.io.IOException;
/**
* Component implementation for Worker process of PyTorch.
*/
public class PyTorchWorkerComponent extends AbstractComponent {
public PyTorchWorkerComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
PyTorchRunJobParameters parameters, Role role,
PyTorchLaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters, role,
yarnConfig, launchCommandFactory);
}
@Override
public Component createComponent() throws IOException {
return createComponentInternal();
}
}

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate
* PyTorch Native Service components.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate
* PyTorch-related Native Service runtime artifacts.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;

View File

@ -20,7 +20,7 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import java.util.Map;
@ -35,10 +35,10 @@ public final class TensorFlowCommons {
}
public static void addCommonEnvironments(Component component,
TaskType taskType) {
Role role) {
Map<String, String> envs = component.getConfiguration().getEnv();
envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
envs.put(Envs.TASK_TYPE_ENV, taskType.name());
envs.put(Envs.TASK_TYPE_ENV, role.getName());
}
public static String getUserName() {
@ -49,8 +49,8 @@ public final class TensorFlowCommons {
return yarnConfig.get("hadoop.registry.dns.domain-name");
}
public static String getScriptFileName(TaskType taskType) {
return "run-" + taskType.name() + ".sh";
public static String getScriptFileName(Role role) {
return "run-" + role.getName() + ".sh";
}
public static String getTFConfigEnv(String componentName, int nWorkers,

View File

@ -16,39 +16,24 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
/**
* This class contains all the logic to create an instance
@ -56,42 +41,34 @@ import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handle
* Worker,PS and Tensorboard components are added to the Service
* based on the value of the received {@link RunJobParameters}.
*/
public class TensorFlowServiceSpec implements ServiceSpec {
public class TensorFlowServiceSpec extends AbstractServiceSpec {
private static final Logger LOG =
LoggerFactory.getLogger(TensorFlowServiceSpec.class);
private final TensorFlowRunJobParameters tensorFlowParameters;
private final RemoteDirectoryManager remoteDirectoryManager;
private final RunJobParameters parameters;
private final Configuration yarnConfig;
private final FileSystemOperations fsOperations;
private final LaunchCommandFactory launchCommandFactory;
private final Localizer localizer;
public TensorFlowServiceSpec(RunJobParameters parameters,
public TensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations,
LaunchCommandFactory launchCommandFactory, Localizer localizer) {
this.parameters = parameters;
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
this.yarnConfig = clientContext.getYarnConfig();
this.fsOperations = fsOperations;
this.launchCommandFactory = launchCommandFactory;
this.localizer = localizer;
TensorFlowLaunchCommandFactory launchCommandFactory,
Localizer localizer) {
super(parameters, clientContext, fsOperations, launchCommandFactory,
localizer);
this.tensorFlowParameters = parameters;
}
@Override
public ServiceWrapper create() throws IOException {
LOG.info("Creating TensorFlow service spec");
ServiceWrapper serviceWrapper = createServiceSpecWrapper();
if (parameters.getNumWorkers() > 0) {
addWorkerComponents(serviceWrapper);
if (tensorFlowParameters.getNumWorkers() > 0) {
addWorkerComponents(serviceWrapper, Framework.TENSORFLOW);
}
if (parameters.getNumPS() > 0) {
if (tensorFlowParameters.getNumPS() > 0) {
addPsComponent(serviceWrapper);
}
if (parameters.isTensorboardEnabled()) {
if (tensorFlowParameters.isTensorboardEnabled()) {
createTensorBoardComponent(serviceWrapper);
}
@ -101,103 +78,23 @@ public class TensorFlowServiceSpec implements ServiceSpec {
return serviceWrapper;
}
private ServiceWrapper createServiceSpecWrapper() throws IOException {
Service serviceSpec = new Service();
serviceSpec.setName(parameters.getName());
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
.create(fsOperations, remoteDirectoryManager, parameters);
if (kerberosPrincipal != null) {
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
}
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
localizer.handleLocalizations(serviceSpec);
return new ServiceWrapper(serviceSpec);
}
private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
throws IOException {
TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
remoteDirectoryManager, parameters,
(TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
serviceWrapper.addComponent(tbComponent);
addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
tbComponent.getTensorboardLink());
}
private static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (quicklinks == null) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
private void handleQuicklinks(Service serviceSpec)
throws IOException {
List<Quicklink> quicklinks = parameters.getQuicklinks();
if (quicklinks != null && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol()
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
// Handle worker and primary_worker.
private void addWorkerComponents(ServiceWrapper serviceWrapper)
throws IOException {
addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
if (parameters.getNumWorkers() > 1) {
addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
}
}
private void addWorkerComponent(ServiceWrapper serviceWrapper,
RunJobParameters parameters, TaskType taskType) throws IOException {
serviceWrapper.addComponent(
new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
parameters, taskType, launchCommandFactory, yarnConfig));
}
private void addPsComponent(ServiceWrapper serviceWrapper)
throws IOException {
serviceWrapper.addComponent(
new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
launchCommandFactory, parameters, yarnConfig));
(TensorFlowLaunchCommandFactory) launchCommandFactory,
parameters, yarnConfig));
}
}

View File

@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.slf4j.Logger;
@ -37,9 +37,9 @@ public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
private final String checkpointPath;
public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters)
Role role, Component component, RunJobParameters parameters)
throws IOException {
super(hadoopEnvSetup, taskType, component, parameters);
super(hadoopEnvSetup, component, parameters, role.getName());
Objects.requireNonNull(parameters.getCheckpointPath(),
"CheckpointPath must not be null as it is part "
+ "of the tensorboard command!");

View File

@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
@ -28,6 +28,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Objects;
/**
* Launch command implementation for
@ -41,13 +42,16 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
private final int numberOfWorkers;
private final int numberOfPS;
private final String name;
private final TaskType taskType;
private final Role role;
TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters,
Role role, Component component,
TensorFlowRunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters);
this.taskType = taskType;
super(hadoopEnvSetup, component, parameters,
role != null ? role.getName(): "");
Objects.requireNonNull(role, "TensorFlowRole must not be null!");
this.role = role;
this.name = parameters.getName();
this.distributed = parameters.isDistributed();
this.numberOfWorkers = parameters.getNumWorkers();
@ -72,7 +76,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
// When distributed training is required
if (distributed) {
String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
taskType.getComponentName(), numberOfWorkers,
role.getComponentName(), numberOfWorkers,
numberOfPS, name,
TensorFlowCommons.getUserName(),
TensorFlowCommons.getDNSDomain(yarnConfig));

View File

@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger;
@ -37,9 +37,10 @@ public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
private final String launchCommand;
public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters,
Role role, Component component,
TensorFlowRunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
super(hadoopEnvSetup, role, component, parameters, yarnConfig);
this.launchCommand = parameters.getPSLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) {

View File

@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger;
@ -37,10 +37,10 @@ public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
private final String launchCommand;
public TensorFlowWorkerLaunchCommand(
HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
Component component, RunJobParameters parameters,
HadoopEnvironmentSetup hadoopEnvSetup, Role role,
Component component, TensorFlowRunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
super(hadoopEnvSetup, role, component, parameters, yarnConfig);
this.launchCommand = parameters.getWorkerLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) {

View File

@ -19,13 +19,14 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.compone
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -54,35 +55,38 @@ public class TensorBoardComponent extends AbstractComponent {
public TensorBoardComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters,
LaunchCommandFactory launchCommandFactory,
TensorFlowLaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters,
TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
TensorFlowRole.TENSORBOARD, yarnConfig, launchCommandFactory);
}
@Override
public Component createComponent() throws IOException {
Objects.requireNonNull(parameters.getTensorboardResource(),
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) this.parameters;
Objects.requireNonNull(tensorFlowParams.getTensorboardResource(),
"TensorBoard resource must not be null!");
Component component = new Component();
component.setName(taskType.getComponentName());
component.setName(role.getComponentName());
component.setNumberOfContainers(1L);
component.setRestartPolicy(RestartPolicyEnum.NEVER);
component.setResource(convertYarnResourceToServiceResource(
parameters.getTensorboardResource()));
tensorFlowParams.getTensorboardResource()));
if (parameters.getTensorboardDockerImage() != null) {
if (tensorFlowParams.getTensorboardDockerImage() != null) {
component.setArtifact(
getDockerArtifact(parameters.getTensorboardDockerImage()));
getDockerArtifact(tensorFlowParams.getTensorboardDockerImage()));
}
addCommonEnvironments(component, taskType);
addCommonEnvironments(component, role);
generateLaunchCommand(component);
tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
parameters.getName(),
taskType.getComponentName() + "-" + 0, getUserName(),
role.getComponentName() + "-" + 0, getUserName(),
getDNSDomain(yarnConfig), DEFAULT_PORT);
LOG.info("Link to tensorboard:" + tensorboardLink);

Some files were not shown because too many files have changed in this diff Show More