SUBMARINE-52. [SUBMARINE-14] Generate Service spec + launch script for single-node PyTorch learning job. Contributed by Szilard Nemeth.
This commit is contained in:
parent
64c7f36ab1
commit
36267b6f7c
|
@ -0,0 +1,77 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Base image: CUDA 10.0 + cuDNN 7 development toolchain on Ubuntu 16.04.
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04

# Python version installed into the conda environment; override with --build-arg.
ARG PYTHON_VERSION=3.6

# Build tools and native libraries used by the PyTorch / torchvision source builds.
RUN apt-get update && apt-get install -y --no-install-recommends \
      build-essential \
      cmake \
      git \
      curl \
      vim \
      ca-certificates \
      libjpeg-dev \
      libpng-dev \
      wget && \
    rm -rf /var/lib/apt/lists/*

# Install Miniconda into /opt/conda together with the packages the PyTorch build needs.
# -f makes curl fail on HTTP errors instead of saving an error page as the installer,
# -L follows redirects of the download host.
RUN curl -fL -o ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
    /opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \
    /opt/conda/bin/conda clean -ya
ENV PATH=/opt/conda/bin:$PATH
RUN pip install ninja

# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/pytorch.git
# Absolute path so the step does not depend on the previous WORKDIR.
WORKDIR /opt/pytorch/pytorch
# Nested submodules are required by the source build, hence --recursive.
RUN git submodule update --init --recursive
RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
    pip install -v .

WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v .

WORKDIR /

# Install Hadoop
ENV HADOOP_VERSION="3.1.2"
# Download, unpack and delete the tarball in ONE layer: removing it in a later
# RUN would leave the archive stored in the earlier layer and bloat the image.
RUN wget https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz && \
    tar zxf hadoop-${HADOOP_VERSION}.tar.gz && \
    ln -s hadoop-${HADOOP_VERSION} hadoop-current && \
    rm hadoop-${HADOOP_VERSION}.tar.gz

ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
RUN echo "Install java8" && \
    apt-get update && \
    apt-get install -y --no-install-recommends openjdk-8-jdk && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

# "scikit-learn" is the real distribution name; "sklearn" is only a deprecated alias on PyPI.
RUN echo "Install python related packages" && \
    pip --no-cache-dir install Pillow h5py ipykernel jupyter matplotlib numpy pandas scipy scikit-learn && \
    python -m ipykernel.kernelspec

# Set the locale to fix bash warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
RUN apt-get update && apt-get install -y --no-install-recommends locales && \
    apt-get clean && rm -rf /var/lib/apt/lists/*
RUN locale-gen en_US.UTF-8

# Working directory for jobs; made writable for every user.
WORKDIR /workspace
RUN chmod -R a+w /workspace
|
|
@ -0,0 +1,30 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Builds the PyTorch GPU base image first, then the CIFAR-10 example image.
# The base image must exist locally before the second build starts.

# Abort on the first failing command, on unset variables and on pipeline failures.
set -euo pipefail

# Resolve the build contexts relative to this script so it can be started from
# any working directory.
# NOTE(review): assumes base/ and with-cifar10-models/ sit next to this script
# (the original relative "cd" chain implies this) - confirm.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BASE_DIR="${SCRIPT_DIR}/base/ubuntu-16.04"
CIFAR10_DIR="${SCRIPT_DIR}/with-cifar10-models/ubuntu-16.04"

echo "Building base images"

docker build "${BASE_DIR}" \
  -f "${BASE_DIR}/Dockerfile.gpu.pytorch_latest" \
  -t pytorch-latest-gpu-base:0.0.1

echo "Finished building base images"

docker build "${CIFAR10_DIR}" \
  -f "${CIFAR10_DIR}/Dockerfile.gpu.pytorch_latest" \
  -t pytorch-latest-gpu:0.0.1
|
|
@ -0,0 +1,354 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Training a Classifier
|
||||
=====================
|
||||
|
||||
This is it. You have seen how to define neural networks, compute loss and make
|
||||
updates to the weights of the network.
|
||||
|
||||
Now you might be thinking,
|
||||
|
||||
What about data?
|
||||
----------------
|
||||
|
||||
Generally, when you have to deal with image, text, audio or video data,
|
||||
you can use standard python packages that load data into a numpy array.
|
||||
Then you can convert this array into a ``torch.*Tensor``.
|
||||
|
||||
- For images, packages such as Pillow, OpenCV are useful
|
||||
- For audio, packages such as scipy and librosa
|
||||
- For text, either raw Python or Cython based loading, or NLTK and
|
||||
SpaCy are useful
|
||||
|
||||
Specifically for vision, we have created a package called
|
||||
``torchvision``, that has data loaders for common datasets such as
|
||||
Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
|
||||
``torchvision.datasets`` and ``torch.utils.data.DataLoader``.
|
||||
|
||||
This provides a huge convenience and avoids writing boilerplate code.
|
||||
|
||||
For this tutorial, we will use the CIFAR10 dataset.
|
||||
It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
|
||||
‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of
|
||||
size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.
|
||||
|
||||
.. figure:: /_static/img/cifar10.png
|
||||
:alt: cifar10
|
||||
|
||||
cifar10
|
||||
|
||||
|
||||
Training an image classifier
|
||||
----------------------------
|
||||
|
||||
We will do the following steps in order:
|
||||
|
||||
1. Load and normalize the CIFAR10 training and test datasets using
|
||||
``torchvision``
|
||||
2. Define a Convolutional Neural Network
|
||||
3. Define a loss function
|
||||
4. Train the network on the training data
|
||||
5. Test the network on the test data
|
||||
|
||||
1. Loading and normalizing CIFAR10
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Using ``torchvision``, it’s extremely easy to load CIFAR10.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
########################################################################
|
||||
# The outputs of torchvision datasets are PILImage images of range [0, 1].
|
||||
# We transform them to Tensors of normalized range [-1, 1].
|
||||
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
|
||||
download=True, transform=transform)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
|
||||
shuffle=True, num_workers=2)
|
||||
|
||||
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
|
||||
download=True, transform=transform)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
|
||||
shuffle=False, num_workers=2)
|
||||
|
||||
classes = ('plane', 'car', 'bird', 'cat',
|
||||
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
||||
|
||||
########################################################################
|
||||
# Let us show some of the training images, for fun.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
# functions to show an image
|
||||
|
||||
|
||||
def imshow(img):
    """Display a normalized image tensor with matplotlib.

    The tensor is halved and shifted by 0.5, which undoes the
    ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` transform applied when
    loading the data. It is then converted to a NumPy array whose axes are
    reordered with ``(1, 2, 0)`` before being handed to ``plt.imshow``.
    """
    unnormalized = img / 2 + 0.5
    pixels = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(pixels)
    plt.show()
|
||||
|
||||
|
||||
# get some random training images
dataiter = iter(trainloader)
# Use the built-in next(): the iterator protocol only guarantees __next__,
# and DataLoader iterators in current PyTorch releases have no .next() method.
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels (one per image in the batch, whatever the batch size is)
print(' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))
|
||||
|
||||
########################################################################
|
||||
# 2. Define a Convolutional Neural Network
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# Copy the neural network from the Neural Networks section before and modify it to
|
||||
# take 3-channel images (instead of 1-channel images as it was defined).
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Net(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images, 10 outputs.

    Layout: conv(3->6, k=5) -> ReLU -> 2x2 max-pool -> conv(6->16, k=5) ->
    ReLU -> 2x2 max-pool -> three fully connected layers (400 -> 120 -> 84 -> 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # For a 32x32 input: 32 -> 28 (conv1) -> 14 (pool) -> 10 (conv2) -> 5 (pool),
        # hence 16 channels * 5 * 5 features feed the first linear layer.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 raw class scores for each image in batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, 400),
        # this never re-interprets a wrongly sized input as a larger batch; the
        # mismatch surfaces as an error in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
|
||||
|
||||
|
||||
net = Net()
|
||||
|
||||
########################################################################
|
||||
# 3. Define a Loss function and optimizer
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# Let's use a Classification Cross-Entropy loss and SGD with momentum.
|
||||
|
||||
import torch.optim as optim
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
||||
|
||||
########################################################################
|
||||
# 4. Train the network
|
||||
# ^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# This is when things start to get interesting.
|
||||
# We simply have to loop over our data iterator, and feed the inputs to the
|
||||
# network and optimize.
|
||||
|
||||
# Two full passes over the training data; the mean loss of each window of
# 2000 mini-batches is printed and the accumulator is then reset.
for epoch in range(2):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward pass, loss, backward pass, parameter update
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if (batch_idx + 1) % 2000 == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
|
||||
|
||||
########################################################################
|
||||
# 5. Test the network on the test data
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#
|
||||
# We have trained the network for 2 passes over the training dataset.
|
||||
# But we need to check if the network has learnt anything at all.
|
||||
#
|
||||
# We will check this by predicting the class label that the neural network
|
||||
# outputs, and checking it against the ground-truth. If the prediction is
|
||||
# correct, we add the sample to the list of correct predictions.
|
||||
#
|
||||
# Okay, first step. Let us display an image from the test set to get familiar.
|
||||
|
||||
dataiter = iter(testloader)
# Use the built-in next(): the iterator protocol only guarantees __next__,
# and DataLoader iterators in current PyTorch releases have no .next() method.
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
# one ground-truth label per image in the batch, whatever the batch size is
print('GroundTruth: ',
      ' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))
|
||||
|
||||
########################################################################
|
||||
# Okay, now let us see what the neural network thinks these examples above are:
|
||||
|
||||
outputs = net(images)
|
||||
|
||||
########################################################################
|
||||
# The outputs are energies for the 10 classes.
|
||||
# The higher the energy for a class, the more the network
|
||||
# thinks that the image is of the particular class.
|
||||
# So, let's get the index of the highest energy:
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
|
||||
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
|
||||
for j in range(4)))
|
||||
|
||||
########################################################################
|
||||
# The results seem pretty good.
|
||||
#
|
||||
# Let us look at how the network performs on the whole dataset.
|
||||
|
||||
# Count how many test images receive the right label.
total = 0
correct = 0
# Evaluation only, so gradient tracking is switched off.
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        # index of the highest score per image
        predicted = torch.max(outputs.data, 1)[1]
        hits = (predicted == labels)
        total += labels.size(0)
        correct += hits.sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
|
||||
|
||||
########################################################################
|
||||
# That looks waaay better than chance, which is 10% accuracy (randomly picking
|
||||
# a class out of 10 classes).
|
||||
# Seems like the network learnt something.
|
||||
#
|
||||
# Hmmm, what are the classes that performed well, and the classes that did
|
||||
# not perform well:
|
||||
|
||||
# Per-class tallies, indexed by class id.
class_correct = [0.] * len(classes)
class_total = [0.] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # Walk the actual batch instead of assuming exactly 4 samples: this
        # stays correct for any batch size, including a short final batch
        # and a batch of one.
        for label, guess in zip(labels.tolist(), predicted.tolist()):
            class_correct[label] += float(label == guess)
            class_total[label] += 1

for i in range(len(classes)):
    # A class without any test sample reports 0 instead of dividing by zero.
    seen = class_total[i]
    percentage = 100 * class_correct[i] / seen if seen else 0
    print('Accuracy of %5s : %2d %%' % (classes[i], percentage))
|
||||
|
||||
########################################################################
|
||||
# Okay, so what next?
|
||||
#
|
||||
# How do we run these neural networks on the GPU?
|
||||
#
|
||||
# Training on GPU
|
||||
# ----------------
|
||||
# Just like how you transfer a Tensor onto the GPU, you transfer the neural
|
||||
# net onto the GPU.
|
||||
#
|
||||
# Let's first define our device as the first visible cuda device if we have
|
||||
# CUDA available:
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Assuming that we are on a CUDA machine, this should print a CUDA device:
|
||||
|
||||
print(device)
|
||||
|
||||
########################################################################
|
||||
# The rest of this section assumes that ``device`` is a CUDA device.
|
||||
#
|
||||
# Then these methods will recursively go over all modules and convert their
|
||||
# parameters and buffers to CUDA tensors:
|
||||
#
|
||||
# .. code:: python
|
||||
#
|
||||
# net.to(device)
|
||||
#
|
||||
#
|
||||
# Remember that you will have to send the inputs and targets at every step
|
||||
# to the GPU too:
|
||||
#
|
||||
# .. code:: python
|
||||
#
|
||||
# inputs, labels = inputs.to(device), labels.to(device)
|
||||
#
|
||||
# Why don't I notice a MASSIVE speedup compared to CPU? Because your network
|
||||
# is realllly small.
|
||||
#
|
||||
# **Exercise:** Try increasing the width of your network (argument 2 of
|
||||
# the first ``nn.Conv2d``, and argument 1 of the second ``nn.Conv2d`` –
|
||||
# they need to be the same number), see what kind of speedup you get.
|
||||
#
|
||||
# **Goals achieved**:
|
||||
#
|
||||
# - Understanding PyTorch's Tensor library and neural networks at a high level.
|
||||
# - Train a small neural network to classify images
|
||||
#
|
||||
# Training on multiple GPUs
|
||||
# -------------------------
|
||||
# If you want to see even more MASSIVE speedup using all of your GPUs,
|
||||
# please check out :doc:`data_parallel_tutorial`.
|
||||
#
|
||||
# Where do I go next?
|
||||
# -------------------
|
||||
#
|
||||
# - :doc:`Train neural nets to play video games </intermediate/reinforcement_q_learning>`
|
||||
# - `Train a state-of-the-art ResNet network on imagenet`_
|
||||
# - `Train a face generator using Generative Adversarial Networks`_
|
||||
# - `Train a word-level language model using Recurrent LSTM networks`_
|
||||
# - `More examples`_
|
||||
# - `More tutorials`_
|
||||
# - `Discuss PyTorch on the Forums`_
|
||||
# - `Chat with other users on Slack`_
|
||||
#
|
||||
# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
|
||||
# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
|
||||
# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
|
||||
# .. _More examples: https://github.com/pytorch/examples
|
||||
# .. _More tutorials: https://github.com/pytorch/tutorials
|
||||
# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
|
||||
# .. _Chat with other users on Slack: https://pytorch.slack.com/messages/beginner/
|
||||
|
||||
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
|
||||
del dataiter
|
||||
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
|
|
@ -0,0 +1,21 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Example image: the PyTorch GPU base image plus the CIFAR-10 tutorial script.
FROM pytorch-latest-gpu-base:0.0.1

# Create the world-writable /test tree (including /test/data) in a single layer.
RUN mkdir -p /test/data && chmod -R 777 /test

# COPY is the recommended instruction for plain local files; ADD's extra
# behaviors (archive extraction, remote URLs) are not needed here.
COPY cifar10_tutorial.py /test/cifar10_tutorial.py
|
|
@ -15,6 +15,7 @@
|
|||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.conf.YarnConfiguration;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
|
||||
|
|
|
@ -59,4 +59,6 @@ public class CliConstants {
|
|||
public static final String DISTRIBUTE_KEYTAB = "distribute_keytab";
|
||||
public static final String YAML_CONFIG = "f";
|
||||
public static final String INSECURE_CLUSTER = "insecure";
|
||||
|
||||
public static final String FRAMEWORK = "framework";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ package org.apache.hadoop.yarn.submarine.client.cli;
|
|||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.security.UserGroupInformation;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.slf4j.Logger;
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
|
||||
/**
 * Enumerates the commands that the Submarine client can execute.
 */
public enum Command {
  /** The "run job" command. */
  RUN_JOB,

  /** The "show job" command. */
  SHOW_JOB
}
|
|
@ -37,7 +37,7 @@ public class ShowJobCli extends AbstractCli {
|
|||
private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class);
|
||||
|
||||
private Options options;
|
||||
private ShowJobParameters parameters = new ShowJobParameters();
|
||||
private ParametersHolder parametersHolder;
|
||||
|
||||
public ShowJobCli(ClientContext cliContext) {
|
||||
super(cliContext);
|
||||
|
@ -62,9 +62,9 @@ public class ShowJobCli extends AbstractCli {
|
|||
CommandLine cli;
|
||||
try {
|
||||
cli = parser.parse(options, args);
|
||||
ParametersHolder parametersHolder = ParametersHolder
|
||||
.createWithCmdLine(cli);
|
||||
parameters.updateParameters(parametersHolder, clientContext);
|
||||
parametersHolder = ParametersHolder
|
||||
.createWithCmdLine(cli, Command.SHOW_JOB);
|
||||
parametersHolder.updateParameters(clientContext);
|
||||
} catch (ParseException e) {
|
||||
printUsages();
|
||||
}
|
||||
|
@ -97,7 +97,7 @@ public class ShowJobCli extends AbstractCli {
|
|||
|
||||
Map<String, String> jobInfo = null;
|
||||
try {
|
||||
jobInfo = storage.getJobInfoByName(parameters.getName());
|
||||
jobInfo = storage.getJobInfoByName(getParameters().getName());
|
||||
} catch (IOException e) {
|
||||
LOG.error("Failed to retrieve job info", e);
|
||||
throw e;
|
||||
|
@ -108,7 +108,7 @@ public class ShowJobCli extends AbstractCli {
|
|||
|
||||
@VisibleForTesting
|
||||
public ShowJobParameters getParameters() {
|
||||
return parameters;
|
||||
return (ShowJobParameters) parametersHolder.getParameters();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.param;
|
||||
|
||||
/**
 * Identifies where a configuration came from.
 */
public enum ConfigType {
  /** Configuration supplied through a YAML file. */
  YAML,

  /** Configuration supplied through command line arguments. */
  CLI
}
|
|
@ -20,8 +20,12 @@ import com.google.common.collect.ImmutableSet;
|
|||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import org.apache.commons.cli.CommandLine;
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.Command;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles;
|
||||
|
@ -29,15 +33,22 @@ import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Scheduling;
|
|||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli.YAML_PARSE_FAILED;
|
||||
|
||||
/**
|
||||
* This class acts as a wrapper of {@code CommandLine} values along with
|
||||
* YAML configuration values.
|
||||
|
@ -52,17 +63,110 @@ public final class ParametersHolder {
|
|||
private static final Logger LOG =
|
||||
LoggerFactory.getLogger(ParametersHolder.class);
|
||||
|
||||
public static final String SUPPORTED_FRAMEWORKS_MESSAGE =
|
||||
"TensorFlow and PyTorch are the only supported frameworks for now!";
|
||||
public static final String SUPPORTED_COMMANDS_MESSAGE =
|
||||
"'Show job' and 'run job' are the only supported commands for now!";
|
||||
|
||||
|
||||
|
||||
private final CommandLine parsedCommandLine;
|
||||
private final Map<String, String> yamlStringConfigs;
|
||||
private final Map<String, List<String>> yamlListConfigs;
|
||||
private final ImmutableSet onlyDefinedWithCliArgs = ImmutableSet.of(
|
||||
private final ConfigType configType;
|
||||
private Command command;
|
||||
private final Set onlyDefinedWithCliArgs = ImmutableSet.of(
|
||||
CliConstants.VERBOSE);
|
||||
private final Framework framework;
|
||||
private final BaseParameters parameters;
|
||||
|
||||
private ParametersHolder(CommandLine parsedCommandLine,
|
||||
YamlConfigFile yamlConfig) {
|
||||
YamlConfigFile yamlConfig, ConfigType configType, Command command)
|
||||
throws ParseException, YarnException {
|
||||
this.parsedCommandLine = parsedCommandLine;
|
||||
this.yamlStringConfigs = initStringConfigValues(yamlConfig);
|
||||
this.yamlListConfigs = initListConfigValues(yamlConfig);
|
||||
this.configType = configType;
|
||||
this.command = command;
|
||||
this.framework = determineFrameworkType();
|
||||
this.ensureOnlyValidSectionsAreDefined(yamlConfig);
|
||||
this.parameters = createParameters();
|
||||
}
|
||||
|
||||
private BaseParameters createParameters() {
|
||||
if (command == Command.RUN_JOB) {
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
return new TensorFlowRunJobParameters();
|
||||
} else if (framework == Framework.PYTORCH) {
|
||||
return new PyTorchRunJobParameters();
|
||||
} else {
|
||||
throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
|
||||
}
|
||||
} else if (command == Command.SHOW_JOB) {
|
||||
return new ShowJobParameters();
|
||||
} else {
|
||||
throw new UnsupportedOperationException(SUPPORTED_COMMANDS_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
private void ensureOnlyValidSectionsAreDefined(YamlConfigFile yamlConfig) {
|
||||
if (isCommandRunJob() && isFrameworkPyTorch() &&
|
||||
isPsSectionDefined(yamlConfig)) {
|
||||
throw new YamlParseException(
|
||||
"PS section should not be defined when PyTorch " +
|
||||
"is the selected framework!");
|
||||
}
|
||||
|
||||
if (isCommandRunJob() && isFrameworkPyTorch() &&
|
||||
isTensorboardSectionDefined(yamlConfig)) {
|
||||
throw new YamlParseException(
|
||||
"TensorBoard section should not be defined when PyTorch " +
|
||||
"is the selected framework!");
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isCommandRunJob() {
|
||||
return command == Command.RUN_JOB;
|
||||
}
|
||||
|
||||
private boolean isFrameworkPyTorch() {
|
||||
return framework == Framework.PYTORCH;
|
||||
}
|
||||
|
||||
private boolean isPsSectionDefined(YamlConfigFile yamlConfig) {
|
||||
return yamlConfig != null &&
|
||||
yamlConfig.getRoles() != null &&
|
||||
yamlConfig.getRoles().getPs() != null;
|
||||
}
|
||||
|
||||
private boolean isTensorboardSectionDefined(YamlConfigFile yamlConfig) {
|
||||
return yamlConfig != null &&
|
||||
yamlConfig.getTensorBoard() != null;
|
||||
}
|
||||
|
||||
private Framework determineFrameworkType()
|
||||
throws ParseException, YarnException {
|
||||
if (!isCommandRunJob()) {
|
||||
return null;
|
||||
}
|
||||
String frameworkStr = getOptionValue(CliConstants.FRAMEWORK);
|
||||
if (frameworkStr == null) {
|
||||
LOG.info("Framework is not defined in config, falling back to " +
|
||||
"TensorFlow as a default.");
|
||||
return Framework.TENSORFLOW;
|
||||
}
|
||||
Framework framework = Framework.parseByValue(frameworkStr);
|
||||
if (framework == null) {
|
||||
if (getConfigType() == ConfigType.CLI) {
|
||||
throw new ParseException("Failed to parse Framework type! "
|
||||
+ "Valid values are: " + Framework.getValues());
|
||||
} else {
|
||||
throw new YamlParseException(YAML_PARSE_FAILED +
|
||||
", framework should is defined, but it has an invalid value! " +
|
||||
"Valid values are: " + Framework.getValues());
|
||||
}
|
||||
}
|
||||
return framework;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -108,6 +212,8 @@ public final class ParametersHolder {
|
|||
private void initGenericConfigs(YamlConfigFile yamlConfig,
|
||||
Map<String, String> yamlConfigs) {
|
||||
yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName());
|
||||
yamlConfigs.put(CliConstants.FRAMEWORK,
|
||||
yamlConfig.getSpec().getFramework());
|
||||
|
||||
Configs configs = yamlConfig.getConfigs();
|
||||
yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath());
|
||||
|
@ -178,13 +284,15 @@ public final class ParametersHolder {
|
|||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static ParametersHolder createWithCmdLine(CommandLine cli) {
|
||||
return new ParametersHolder(cli, null);
|
||||
public static ParametersHolder createWithCmdLine(CommandLine cli,
|
||||
Command command) throws ParseException, YarnException {
|
||||
return new ParametersHolder(cli, null, ConfigType.CLI, command);
|
||||
}
|
||||
|
||||
public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli,
|
||||
YamlConfigFile yamlConfig) {
|
||||
return new ParametersHolder(cli, yamlConfig);
|
||||
YamlConfigFile yamlConfig, Command command) throws ParseException,
|
||||
YarnException {
|
||||
return new ParametersHolder(cli, yamlConfig, ConfigType.YAML, command);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -193,7 +301,7 @@ public final class ParametersHolder {
|
|||
* @param option Name of the config.
|
||||
* @return The value of the config
|
||||
*/
|
||||
String getOptionValue(String option) throws YarnException {
|
||||
public String getOptionValue(String option) throws YarnException {
|
||||
ensureConfigIsDefinedOnce(option, true);
|
||||
if (onlyDefinedWithCliArgs.contains(option) ||
|
||||
parsedCommandLine.hasOption(option)) {
|
||||
|
@ -208,7 +316,7 @@ public final class ParametersHolder {
|
|||
* @param option Name of the config.
|
||||
* @return The values of the config
|
||||
*/
|
||||
List<String> getOptionValues(String option) throws YarnException {
|
||||
public List<String> getOptionValues(String option) throws YarnException {
|
||||
ensureConfigIsDefinedOnce(option, false);
|
||||
if (onlyDefinedWithCliArgs.contains(option) ||
|
||||
parsedCommandLine.hasOption(option)) {
|
||||
|
@ -285,7 +393,7 @@ public final class ParametersHolder {
|
|||
* @return true, if the option is found in the CLI args or in the YAML config,
|
||||
* false otherwise.
|
||||
*/
|
||||
boolean hasOption(String option) {
|
||||
public boolean hasOption(String option) {
|
||||
if (onlyDefinedWithCliArgs.contains(option)) {
|
||||
boolean value = parsedCommandLine.hasOption(option);
|
||||
if (LOG.isDebugEnabled()) {
|
||||
|
@ -312,4 +420,21 @@ public final class ParametersHolder {
|
|||
"from YAML configuration.", result, option);
|
||||
return result;
|
||||
}
|
||||
|
||||
public ConfigType getConfigType() {
|
||||
return configType;
|
||||
}
|
||||
|
||||
public Framework getFramework() {
|
||||
return framework;
|
||||
}
|
||||
|
||||
public void updateParameters(ClientContext clientContext)
|
||||
throws ParseException, YarnException, IOException {
|
||||
parameters.updateParameters(this, clientContext);
|
||||
}
|
||||
|
||||
public BaseParameters getParameters() {
|
||||
return parameters;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
/**
|
||||
* Parameters for PyTorch job.
|
||||
*/
|
||||
public class PyTorchRunJobParameters extends RunJobParameters {
|
||||
|
||||
private static final String CANNOT_BE_DEFINED_FOR_PYTORCH =
|
||||
"cannot be defined for PyTorch jobs!";
|
||||
|
||||
@Override
|
||||
public void updateParameters(ParametersHolder parametersHolder,
|
||||
ClientContext clientContext)
|
||||
throws ParseException, IOException, YarnException {
|
||||
checkArguments(parametersHolder);
|
||||
|
||||
super.updateParameters(parametersHolder, clientContext);
|
||||
|
||||
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
|
||||
this.workerParameters =
|
||||
getWorkerParameters(clientContext, parametersHolder, input);
|
||||
this.distributed = determineIfDistributed(workerParameters.getReplicas());
|
||||
executePostOperations(clientContext);
|
||||
}
|
||||
|
||||
private void checkArguments(ParametersHolder parametersHolder)
|
||||
throws YarnException, ParseException {
|
||||
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.N_PS));
|
||||
} else if (parametersHolder.getOptionValue(CliConstants.PS_RES) != null) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.PS_RES));
|
||||
} else if (parametersHolder
|
||||
.getOptionValue(CliConstants.PS_DOCKER_IMAGE) != null) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.PS_DOCKER_IMAGE));
|
||||
} else if (parametersHolder
|
||||
.getOptionValue(CliConstants.PS_LAUNCH_CMD) != null) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.PS_LAUNCH_CMD));
|
||||
} else if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.TENSORBOARD));
|
||||
} else if (parametersHolder
|
||||
.getOptionValue(CliConstants.TENSORBOARD_RESOURCES) != null) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.TENSORBOARD_RESOURCES));
|
||||
} else if (parametersHolder
|
||||
.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE) != null) {
|
||||
throw new ParseException(getParamCannotBeDefinedErrorMessage(
|
||||
CliConstants.TENSORBOARD_DOCKER_IMAGE));
|
||||
}
|
||||
}
|
||||
|
||||
private String getParamCannotBeDefinedErrorMessage(String cliName) {
|
||||
return String.format(
|
||||
"Parameter '%s' " + CANNOT_BE_DEFINED_FOR_PYTORCH, cliName);
|
||||
}
|
||||
|
||||
@Override
|
||||
void executePostOperations(ClientContext clientContext) throws IOException {
|
||||
// Set default job dir / saved model dir, etc.
|
||||
setDefaultDirs(clientContext);
|
||||
replacePatternsInParameters(clientContext);
|
||||
}
|
||||
|
||||
private void replacePatternsInParameters(ClientContext clientContext)
|
||||
throws IOException {
|
||||
if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
|
||||
String afterReplace =
|
||||
CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
|
||||
clientContext.getRemoteDirectoryManager());
|
||||
setWorkerLaunchCmd(afterReplace);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getLaunchCommands() {
|
||||
return Lists.newArrayList(getWorkerLaunchCmd());
|
||||
}
|
||||
|
||||
/**
|
||||
* We only support non-distributed PyTorch integration for now.
|
||||
* @param nWorkers
|
||||
* @return
|
||||
*/
|
||||
private boolean determineIfDistributed(int nWorkers) {
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -1,3 +1,19 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -12,7 +28,7 @@
|
|||
* limitations under the License. See accompanying LICENSE file.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.param;
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.CaseFormat;
|
||||
|
@ -21,7 +37,14 @@ import org.apache.hadoop.yarn.api.records.Resource;
|
|||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
|
||||
import org.yaml.snakeyaml.introspector.Property;
|
||||
import org.yaml.snakeyaml.introspector.PropertyUtils;
|
||||
|
@ -34,27 +57,15 @@ import java.util.List;
|
|||
/**
|
||||
* Parameters used to run a job
|
||||
*/
|
||||
public class RunJobParameters extends RunParameters {
|
||||
public abstract class RunJobParameters extends RunParameters {
|
||||
private String input;
|
||||
private String checkpointPath;
|
||||
|
||||
private int numWorkers;
|
||||
private int numPS;
|
||||
private Resource workerResource;
|
||||
private Resource psResource;
|
||||
private boolean tensorboardEnabled;
|
||||
private Resource tensorboardResource;
|
||||
private String tensorboardDockerImage;
|
||||
private String workerLaunchCmd;
|
||||
private String psLaunchCmd;
|
||||
private List<Quicklink> quicklinks = new ArrayList<>();
|
||||
private List<Localization> localizations = new ArrayList<>();
|
||||
|
||||
private String psDockerImage = null;
|
||||
private String workerDockerImage = null;
|
||||
|
||||
private boolean waitJobFinish = false;
|
||||
private boolean distributed = false;
|
||||
protected boolean distributed = false;
|
||||
|
||||
private boolean securityDisabled = false;
|
||||
private String keytab;
|
||||
|
@ -62,6 +73,9 @@ public class RunJobParameters extends RunParameters {
|
|||
private boolean distributeKeytab = false;
|
||||
private List<String> confPairs = new ArrayList<>();
|
||||
|
||||
RoleParameters workerParameters =
|
||||
RoleParameters.createEmpty(TensorFlowRole.WORKER);
|
||||
|
||||
@Override
|
||||
public void updateParameters(ParametersHolder parametersHolder,
|
||||
ClientContext clientContext)
|
||||
|
@ -70,34 +84,6 @@ public class RunJobParameters extends RunParameters {
|
|||
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
|
||||
String jobDir = parametersHolder.getOptionValue(
|
||||
CliConstants.CHECKPOINT_PATH);
|
||||
int nWorkers = 1;
|
||||
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
|
||||
nWorkers = Integer.parseInt(
|
||||
parametersHolder.getOptionValue(CliConstants.N_WORKERS));
|
||||
// Only check null value.
|
||||
// Training job shouldn't ignore INPUT_PATH option
|
||||
// But if nWorkers is 0, INPUT_PATH can be ignored because
|
||||
// user can only run Tensorboard
|
||||
if (null == input && 0 != nWorkers) {
|
||||
throw new ParseException("\"--" + CliConstants.INPUT_PATH +
|
||||
"\" is absent");
|
||||
}
|
||||
}
|
||||
|
||||
int nPS = 0;
|
||||
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
|
||||
nPS = Integer.parseInt(
|
||||
parametersHolder.getOptionValue(CliConstants.N_PS));
|
||||
}
|
||||
|
||||
// Check #workers and #ps.
|
||||
// When distributed training is required
|
||||
if (nWorkers >= 2 && nPS > 0) {
|
||||
distributed = true;
|
||||
} else if (nWorkers <= 1 && nPS > 0) {
|
||||
throw new ParseException("Only specified one worker but non-zero PS, "
|
||||
+ "please double check.");
|
||||
}
|
||||
|
||||
if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) {
|
||||
setSecurityDisabled(true);
|
||||
|
@ -109,46 +95,6 @@ public class RunJobParameters extends RunParameters {
|
|||
CliConstants.PRINCIPAL);
|
||||
CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal);
|
||||
|
||||
workerResource = null;
|
||||
if (nWorkers > 0) {
|
||||
String workerResourceStr = parametersHolder.getOptionValue(
|
||||
CliConstants.WORKER_RES);
|
||||
if (workerResourceStr == null) {
|
||||
throw new ParseException(
|
||||
"--" + CliConstants.WORKER_RES + " is absent.");
|
||||
}
|
||||
workerResource = ResourceUtils.createResourceFromString(
|
||||
workerResourceStr,
|
||||
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||
}
|
||||
|
||||
Resource psResource = null;
|
||||
if (nPS > 0) {
|
||||
String psResourceStr = parametersHolder.getOptionValue(
|
||||
CliConstants.PS_RES);
|
||||
if (psResourceStr == null) {
|
||||
throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
|
||||
}
|
||||
psResource = ResourceUtils.createResourceFromString(psResourceStr,
|
||||
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||
}
|
||||
|
||||
boolean tensorboard = false;
|
||||
if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
|
||||
tensorboard = true;
|
||||
String tensorboardResourceStr = parametersHolder.getOptionValue(
|
||||
CliConstants.TENSORBOARD_RESOURCES);
|
||||
if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
|
||||
tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
|
||||
}
|
||||
tensorboardResource = ResourceUtils.createResourceFromString(
|
||||
tensorboardResourceStr,
|
||||
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||
tensorboardDockerImage = parametersHolder.getOptionValue(
|
||||
CliConstants.TENSORBOARD_DOCKER_IMAGE);
|
||||
this.setTensorboardResource(tensorboardResource);
|
||||
}
|
||||
|
||||
if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) {
|
||||
this.waitJobFinish = true;
|
||||
}
|
||||
|
@ -164,16 +110,6 @@ public class RunJobParameters extends RunParameters {
|
|||
}
|
||||
}
|
||||
|
||||
psDockerImage = parametersHolder.getOptionValue(
|
||||
CliConstants.PS_DOCKER_IMAGE);
|
||||
workerDockerImage = parametersHolder.getOptionValue(
|
||||
CliConstants.WORKER_DOCKER_IMAGE);
|
||||
|
||||
String workerLaunchCmd = parametersHolder.getOptionValue(
|
||||
CliConstants.WORKER_LAUNCH_CMD);
|
||||
String psLaunchCommand = parametersHolder.getOptionValue(
|
||||
CliConstants.PS_LAUNCH_CMD);
|
||||
|
||||
// Localizations
|
||||
List<String> localizationsStr = parametersHolder.getOptionValues(
|
||||
CliConstants.LOCALIZATION);
|
||||
|
@ -191,10 +127,6 @@ public class RunJobParameters extends RunParameters {
|
|||
.getOptionValues(CliConstants.ARG_CONF);
|
||||
|
||||
this.setInputPath(input).setCheckpointPath(jobDir)
|
||||
.setNumPS(nPS).setNumWorkers(nWorkers)
|
||||
.setPSLaunchCmd(psLaunchCommand).setWorkerLaunchCmd(workerLaunchCmd)
|
||||
.setPsResource(psResource)
|
||||
.setTensorboardEnabled(tensorboard)
|
||||
.setKeytab(kerberosKeytab)
|
||||
.setPrincipal(kerberosPrincipal)
|
||||
.setDistributeKeytab(distributeKerberosKeytab)
|
||||
|
@ -203,6 +135,39 @@ public class RunJobParameters extends RunParameters {
|
|||
super.updateParameters(parametersHolder, clientContext);
|
||||
}
|
||||
|
||||
abstract void executePostOperations(ClientContext clientContext)
|
||||
throws IOException;
|
||||
|
||||
void setDefaultDirs(ClientContext clientContext) throws IOException {
|
||||
// Create directories if needed
|
||||
String jobDir = getCheckpointPath();
|
||||
if (jobDir == null) {
|
||||
jobDir = getJobDir(clientContext);
|
||||
setCheckpointPath(jobDir);
|
||||
}
|
||||
|
||||
if (getNumWorkers() > 0) {
|
||||
String savedModelDir = getSavedModelPath();
|
||||
if (savedModelDir == null) {
|
||||
savedModelDir = jobDir;
|
||||
setSavedModelPath(savedModelDir);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String getJobDir(ClientContext clientContext) throws IOException {
|
||||
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
|
||||
if (getNumWorkers() > 0) {
|
||||
return rdm.getJobCheckpointDir(getName(), true).toString();
|
||||
} else {
|
||||
// when #workers == 0, it means we only launch TB. In that case,
|
||||
// point job dir to root dir so all job's metrics will be shown.
|
||||
return rdm.getUserRootFolder().toString();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @return the launch commands of the job, as defined by the concrete
 *     subclass
 */
public abstract List<String> getLaunchCommands();
|
||||
|
||||
public String getInputPath() {
|
||||
return input;
|
||||
}
|
||||
|
@ -221,110 +186,10 @@ public class RunJobParameters extends RunParameters {
|
|||
return this;
|
||||
}
|
||||
|
||||
public int getNumWorkers() {
|
||||
return numWorkers;
|
||||
}
|
||||
|
||||
public RunJobParameters setNumWorkers(int numWorkers) {
|
||||
this.numWorkers = numWorkers;
|
||||
return this;
|
||||
}
|
||||
|
||||
public int getNumPS() {
|
||||
return numPS;
|
||||
}
|
||||
|
||||
public RunJobParameters setNumPS(int numPS) {
|
||||
this.numPS = numPS;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Resource getWorkerResource() {
|
||||
return workerResource;
|
||||
}
|
||||
|
||||
public RunJobParameters setWorkerResource(Resource workerResource) {
|
||||
this.workerResource = workerResource;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Resource getPsResource() {
|
||||
return psResource;
|
||||
}
|
||||
|
||||
public RunJobParameters setPsResource(Resource psResource) {
|
||||
this.psResource = psResource;
|
||||
return this;
|
||||
}
|
||||
|
||||
public boolean isTensorboardEnabled() {
|
||||
return tensorboardEnabled;
|
||||
}
|
||||
|
||||
public RunJobParameters setTensorboardEnabled(boolean tensorboardEnabled) {
|
||||
this.tensorboardEnabled = tensorboardEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getWorkerLaunchCmd() {
|
||||
return workerLaunchCmd;
|
||||
}
|
||||
|
||||
public RunJobParameters setWorkerLaunchCmd(String workerLaunchCmd) {
|
||||
this.workerLaunchCmd = workerLaunchCmd;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getPSLaunchCmd() {
|
||||
return psLaunchCmd;
|
||||
}
|
||||
|
||||
public RunJobParameters setPSLaunchCmd(String psLaunchCmd) {
|
||||
this.psLaunchCmd = psLaunchCmd;
|
||||
return this;
|
||||
}
|
||||
|
||||
public boolean isWaitJobFinish() {
|
||||
return waitJobFinish;
|
||||
}
|
||||
|
||||
|
||||
public String getPsDockerImage() {
|
||||
return psDockerImage;
|
||||
}
|
||||
|
||||
public void setPsDockerImage(String psDockerImage) {
|
||||
this.psDockerImage = psDockerImage;
|
||||
}
|
||||
|
||||
public String getWorkerDockerImage() {
|
||||
return workerDockerImage;
|
||||
}
|
||||
|
||||
public void setWorkerDockerImage(String workerDockerImage) {
|
||||
this.workerDockerImage = workerDockerImage;
|
||||
}
|
||||
|
||||
public boolean isDistributed() {
|
||||
return distributed;
|
||||
}
|
||||
|
||||
public Resource getTensorboardResource() {
|
||||
return tensorboardResource;
|
||||
}
|
||||
|
||||
public void setTensorboardResource(Resource tensorboardResource) {
|
||||
this.tensorboardResource = tensorboardResource;
|
||||
}
|
||||
|
||||
public String getTensorboardDockerImage() {
|
||||
return tensorboardDockerImage;
|
||||
}
|
||||
|
||||
public void setTensorboardDockerImage(String tensorboardDockerImage) {
|
||||
this.tensorboardDockerImage = tensorboardDockerImage;
|
||||
}
|
||||
|
||||
public List<Quicklink> getQuicklinks() {
|
||||
return quicklinks;
|
||||
}
|
||||
|
@ -382,6 +247,90 @@ public class RunJobParameters extends RunParameters {
|
|||
this.distributed = distributed;
|
||||
}
|
||||
|
||||
RoleParameters getWorkerParameters(ClientContext clientContext,
|
||||
ParametersHolder parametersHolder, String input)
|
||||
throws ParseException, YarnException, IOException {
|
||||
int nWorkers = getNumberOfWorkers(parametersHolder, input);
|
||||
Resource workerResource =
|
||||
determineWorkerResource(parametersHolder, nWorkers, clientContext);
|
||||
String workerDockerImage =
|
||||
parametersHolder.getOptionValue(CliConstants.WORKER_DOCKER_IMAGE);
|
||||
String workerLaunchCmd =
|
||||
parametersHolder.getOptionValue(CliConstants.WORKER_LAUNCH_CMD);
|
||||
return new RoleParameters(TensorFlowRole.WORKER, nWorkers,
|
||||
workerLaunchCmd, workerDockerImage, workerResource);
|
||||
}
|
||||
|
||||
private Resource determineWorkerResource(ParametersHolder parametersHolder,
|
||||
int nWorkers, ClientContext clientContext)
|
||||
throws ParseException, YarnException, IOException {
|
||||
if (nWorkers > 0) {
|
||||
String workerResourceStr =
|
||||
parametersHolder.getOptionValue(CliConstants.WORKER_RES);
|
||||
if (workerResourceStr == null) {
|
||||
throw new ParseException(
|
||||
"--" + CliConstants.WORKER_RES + " is absent.");
|
||||
}
|
||||
return ResourceUtils.createResourceFromString(workerResourceStr,
|
||||
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private int getNumberOfWorkers(ParametersHolder parametersHolder,
|
||||
String input) throws ParseException, YarnException {
|
||||
int nWorkers = 1;
|
||||
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
|
||||
nWorkers = Integer
|
||||
.parseInt(parametersHolder.getOptionValue(CliConstants.N_WORKERS));
|
||||
// Only check null value.
|
||||
// Training job shouldn't ignore INPUT_PATH option
|
||||
// But if nWorkers is 0, INPUT_PATH can be ignored because
|
||||
// user can only run Tensorboard
|
||||
if (null == input && 0 != nWorkers) {
|
||||
throw new ParseException(
|
||||
"\"--" + CliConstants.INPUT_PATH + "\" is absent");
|
||||
}
|
||||
}
|
||||
return nWorkers;
|
||||
}
|
||||
|
||||
public String getWorkerLaunchCmd() {
|
||||
return workerParameters.getLaunchCommand();
|
||||
}
|
||||
|
||||
public void setWorkerLaunchCmd(String launchCmd) {
|
||||
workerParameters.setLaunchCommand(launchCmd);
|
||||
}
|
||||
|
||||
public int getNumWorkers() {
|
||||
return workerParameters.getReplicas();
|
||||
}
|
||||
|
||||
public void setNumWorkers(int numWorkers) {
|
||||
workerParameters.setReplicas(numWorkers);
|
||||
}
|
||||
|
||||
public Resource getWorkerResource() {
|
||||
return workerParameters.getResource();
|
||||
}
|
||||
|
||||
public void setWorkerResource(Resource resource) {
|
||||
workerParameters.setResource(resource);
|
||||
}
|
||||
|
||||
public String getWorkerDockerImage() {
|
||||
return workerParameters.getDockerImage();
|
||||
}
|
||||
|
||||
public void setWorkerDockerImage(String image) {
|
||||
workerParameters.setDockerImage(image);
|
||||
}
|
||||
|
||||
public boolean isDistributed() {
|
||||
return distributed;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
|
||||
@Override
|
|
@ -0,0 +1,215 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.yarn.api.records.Resource;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
|
||||
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Parameters for TensorFlow job.
|
||||
*/
|
||||
public class TensorFlowRunJobParameters extends RunJobParameters {
|
||||
private boolean tensorboardEnabled;
|
||||
private RoleParameters psParameters =
|
||||
RoleParameters.createEmpty(TensorFlowRole.PS);
|
||||
private RoleParameters tensorBoardParameters =
|
||||
RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);
|
||||
|
||||
@Override
|
||||
public void updateParameters(ParametersHolder parametersHolder,
|
||||
ClientContext clientContext)
|
||||
throws ParseException, IOException, YarnException {
|
||||
super.updateParameters(parametersHolder, clientContext);
|
||||
|
||||
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
|
||||
this.workerParameters =
|
||||
getWorkerParameters(clientContext, parametersHolder, input);
|
||||
this.psParameters = getPSParameters(clientContext, parametersHolder);
|
||||
this.distributed = determineIfDistributed(workerParameters.getReplicas(),
|
||||
psParameters.getReplicas());
|
||||
|
||||
if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
|
||||
this.tensorboardEnabled = true;
|
||||
this.tensorBoardParameters =
|
||||
getTensorBoardParameters(parametersHolder, clientContext);
|
||||
}
|
||||
executePostOperations(clientContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
void executePostOperations(ClientContext clientContext) throws IOException {
|
||||
// Set default job dir / saved model dir, etc.
|
||||
setDefaultDirs(clientContext);
|
||||
replacePatternsInParameters(clientContext);
|
||||
}
|
||||
|
||||
private void replacePatternsInParameters(ClientContext clientContext)
|
||||
throws IOException {
|
||||
if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
|
||||
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
|
||||
getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
|
||||
setPSLaunchCmd(afterReplace);
|
||||
}
|
||||
|
||||
if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
|
||||
String afterReplace =
|
||||
CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
|
||||
clientContext.getRemoteDirectoryManager());
|
||||
setWorkerLaunchCmd(afterReplace);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getLaunchCommands() {
|
||||
return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
|
||||
}
|
||||
|
||||
private boolean determineIfDistributed(int nWorkers, int nPS)
|
||||
throws ParseException {
|
||||
// Check #workers and #ps.
|
||||
// When distributed training is required
|
||||
if (nWorkers >= 2 && nPS > 0) {
|
||||
return true;
|
||||
} else if (nWorkers <= 1 && nPS > 0) {
|
||||
throw new ParseException("Only specified one worker but non-zero PS, "
|
||||
+ "please double check.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private RoleParameters getPSParameters(ClientContext clientContext,
|
||||
ParametersHolder parametersHolder)
|
||||
throws YarnException, IOException, ParseException {
|
||||
int nPS = getNumberOfPS(parametersHolder);
|
||||
Resource psResource =
|
||||
determinePSResource(parametersHolder, nPS, clientContext);
|
||||
String psDockerImage =
|
||||
parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
|
||||
String psLaunchCommand =
|
||||
parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
|
||||
return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
|
||||
psDockerImage, psResource);
|
||||
}
|
||||
|
||||
private Resource determinePSResource(ParametersHolder parametersHolder,
|
||||
int nPS, ClientContext clientContext)
|
||||
throws ParseException, YarnException, IOException {
|
||||
if (nPS > 0) {
|
||||
String psResourceStr =
|
||||
parametersHolder.getOptionValue(CliConstants.PS_RES);
|
||||
if (psResourceStr == null) {
|
||||
throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
|
||||
}
|
||||
return ResourceUtils.createResourceFromString(psResourceStr,
|
||||
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private int getNumberOfPS(ParametersHolder parametersHolder)
|
||||
throws YarnException {
|
||||
int nPS = 0;
|
||||
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
|
||||
nPS =
|
||||
Integer.parseInt(parametersHolder.getOptionValue(CliConstants.N_PS));
|
||||
}
|
||||
return nPS;
|
||||
}
|
||||
|
||||
private RoleParameters getTensorBoardParameters(
|
||||
ParametersHolder parametersHolder, ClientContext clientContext)
|
||||
throws YarnException, IOException {
|
||||
String tensorboardResourceStr =
|
||||
parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
|
||||
if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
|
||||
tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
|
||||
}
|
||||
Resource tensorboardResource =
|
||||
ResourceUtils.createResourceFromString(tensorboardResourceStr,
|
||||
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||
String tensorboardDockerImage =
|
||||
parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
|
||||
return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
|
||||
tensorboardDockerImage, tensorboardResource);
|
||||
}
|
||||
|
||||
public int getNumPS() {
|
||||
return psParameters.getReplicas();
|
||||
}
|
||||
|
||||
public void setNumPS(int numPS) {
|
||||
psParameters.setReplicas(numPS);
|
||||
}
|
||||
|
||||
public Resource getPsResource() {
|
||||
return psParameters.getResource();
|
||||
}
|
||||
|
||||
public void setPsResource(Resource resource) {
|
||||
psParameters.setResource(resource);
|
||||
}
|
||||
|
||||
public String getPsDockerImage() {
|
||||
return psParameters.getDockerImage();
|
||||
}
|
||||
|
||||
public void setPsDockerImage(String image) {
|
||||
psParameters.setDockerImage(image);
|
||||
}
|
||||
|
||||
public String getPSLaunchCmd() {
|
||||
return psParameters.getLaunchCommand();
|
||||
}
|
||||
|
||||
public void setPSLaunchCmd(String launchCmd) {
|
||||
psParameters.setLaunchCommand(launchCmd);
|
||||
}
|
||||
|
||||
public boolean isTensorboardEnabled() {
|
||||
return tensorboardEnabled;
|
||||
}
|
||||
|
||||
public Resource getTensorboardResource() {
|
||||
return tensorBoardParameters.getResource();
|
||||
}
|
||||
|
||||
public void setTensorboardResource(Resource resource) {
|
||||
tensorBoardParameters.setResource(resource);
|
||||
}
|
||||
|
||||
public String getTensorboardDockerImage() {
|
||||
return tensorBoardParameters.getDockerImage();
|
||||
}
|
||||
|
||||
public void setTensorboardDockerImage(String image) {
|
||||
tensorBoardParameters.setDockerImage(image);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* This package contains classes that hold run job parameters for
|
||||
* TensorFlow / PyTorch jobs.
|
||||
*/
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
|
|
@ -22,6 +22,7 @@ package org.apache.hadoop.yarn.submarine.client.cli.param.yaml;
|
|||
public class Spec {
|
||||
private String name;
|
||||
private String jobType;
|
||||
private String framework;
|
||||
|
||||
public String getJobType() {
|
||||
return jobType;
|
||||
|
@ -38,4 +39,12 @@ public class Spec {
|
|||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public String getFramework() {
|
||||
return framework;
|
||||
}
|
||||
|
||||
public void setFramework(String framework) {
|
||||
this.framework = framework;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
 * Represents the type of Machine learning framework to work with.
 * Each constant carries the lower-case textual value that identifies it.
 */
public enum Framework {
  TENSORFLOW(Constants.TENSORFLOW_NAME), PYTORCH(Constants.PYTORCH_NAME);

  // Enum state is shared by every user of the constant, so keep it immutable.
  private final String value;

  Framework(String value) {
    this.value = value;
  }

  /**
   * @return the textual value of this framework
   */
  public String getValue() {
    return value;
  }

  /**
   * Finds the framework whose textual value matches the given string,
   * ignoring case.
   *
   * @param value textual value to look up, may be null
   * @return the matching framework, or null when nothing matches
   */
  public static Framework parseByValue(String value) {
    for (Framework fw : Framework.values()) {
      if (fw.value.equalsIgnoreCase(value)) {
        return fw;
      }
    }
    return null;
  }

  /**
   * @return the textual values of all frameworks in declaration order,
   *         joined with a comma
   */
  public static String getValues() {
    // Stream the array directly; copying it into a list first is not needed.
    List<String> values = java.util.Arrays.stream(Framework.values())
        .map(fw -> fw.value).collect(Collectors.toList());
    return String.join(",", values);
  }

  /**
   * Holds the textual values so that they can be referenced from the
   * enum constant declarations above.
   */
  private static final class Constants {
    static final String TENSORFLOW_NAME = "tensorflow";
    static final String PYTORCH_NAME = "pytorch";

    private Constants() {
      // constants holder, not meant to be instantiated
    }
  }
}
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
||||
|
||||
import org.apache.hadoop.yarn.api.records.Resource;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
|
||||
/**
|
||||
* This class encapsulates data related to a particular Role.
|
||||
* Some examples: TF Worker process, TF PS process or a PyTorch worker process.
|
||||
*/
|
||||
public class RoleParameters {
|
||||
private final Role role;
|
||||
private int replicas;
|
||||
private String launchCommand;
|
||||
private String dockerImage;
|
||||
private Resource resource;
|
||||
|
||||
public RoleParameters(Role role, int replicas,
|
||||
String launchCommand, String dockerImage, Resource resource) {
|
||||
this.role = role;
|
||||
this.replicas = replicas;
|
||||
this.launchCommand = launchCommand;
|
||||
this.dockerImage = dockerImage;
|
||||
this.resource = resource;
|
||||
}
|
||||
|
||||
public static RoleParameters createEmpty(Role role) {
|
||||
return new RoleParameters(role, 0, null, null, null);
|
||||
}
|
||||
|
||||
public Role getRole() {
|
||||
return role;
|
||||
}
|
||||
|
||||
public int getReplicas() {
|
||||
return replicas;
|
||||
}
|
||||
|
||||
public String getLaunchCommand() {
|
||||
return launchCommand;
|
||||
}
|
||||
|
||||
public void setLaunchCommand(String launchCommand) {
|
||||
this.launchCommand = launchCommand;
|
||||
}
|
||||
|
||||
public String getDockerImage() {
|
||||
return dockerImage;
|
||||
}
|
||||
|
||||
public void setDockerImage(String dockerImage) {
|
||||
this.dockerImage = dockerImage;
|
||||
}
|
||||
|
||||
public Resource getResource() {
|
||||
return resource;
|
||||
}
|
||||
|
||||
public void setResource(Resource resource) {
|
||||
this.resource = resource;
|
||||
}
|
||||
|
||||
public void setReplicas(int replicas) {
|
||||
this.replicas = replicas;
|
||||
}
|
||||
}
|
|
@ -1,3 +1,19 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -12,7 +28,7 @@
|
|||
* limitations under the License. See accompanying LICENSE file.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import org.apache.commons.cli.CommandLine;
|
||||
|
@ -23,9 +39,13 @@ import org.apache.commons.cli.ParseException;
|
|||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.AbstractCli;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.Command;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
|
@ -44,17 +64,25 @@ import java.io.IOException;
|
|||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* The purpose of this class is to handle / parse CLI arguments related to
|
||||
* the run job Submarine command.
|
||||
*/
|
||||
public class RunJobCli extends AbstractCli {
|
||||
private static final Logger LOG =
|
||||
LoggerFactory.getLogger(RunJobCli.class);
|
||||
private static final String YAML_PARSE_FAILED = "Failed to parse " +
|
||||
private static final String CAN_BE_USED_WITH_TF_PYTORCH =
|
||||
"Can be used with TensorFlow or PyTorch frameworks.";
|
||||
private static final String CAN_BE_USED_WITH_TF_ONLY =
|
||||
"Can only be used with TensorFlow framework.";
|
||||
public static final String YAML_PARSE_FAILED = "Failed to parse " +
|
||||
"YAML config";
|
||||
|
||||
private Options options;
|
||||
private RunJobParameters parameters = new RunJobParameters();
|
||||
|
||||
private Options options;
|
||||
private JobSubmitter jobSubmitter;
|
||||
private JobMonitor jobMonitor;
|
||||
private ParametersHolder parametersHolder;
|
||||
|
||||
public RunJobCli(ClientContext cliContext) {
|
||||
this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
|
||||
|
@ -62,7 +90,7 @@ public class RunJobCli extends AbstractCli {
|
|||
}
|
||||
|
||||
@VisibleForTesting
|
||||
RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
|
||||
public RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
|
||||
JobMonitor jobMonitor) {
|
||||
super(cliContext);
|
||||
this.options = generateOptions();
|
||||
|
@ -78,6 +106,10 @@ public class RunJobCli extends AbstractCli {
|
|||
Options options = new Options();
|
||||
options.addOption(CliConstants.YAML_CONFIG, true,
|
||||
"Config file (in YAML format)");
|
||||
options.addOption(CliConstants.FRAMEWORK, true,
|
||||
String.format("Framework to use. Valid values are: %s! " +
|
||||
"The default framework is Tensorflow.",
|
||||
Framework.getValues()));
|
||||
options.addOption(CliConstants.NAME, true, "Name of the job");
|
||||
options.addOption(CliConstants.INPUT_PATH, true,
|
||||
"Input of the job, could be local or other FS directory");
|
||||
|
@ -88,48 +120,22 @@ public class RunJobCli extends AbstractCli {
|
|||
options.addOption(CliConstants.SAVED_MODEL_PATH, true,
|
||||
"Model exported path (savedmodel) of the job, which is needed when "
|
||||
+ "exported model is not placed under ${checkpoint_path}"
|
||||
+ "could be local or other FS directory. This will be used to serve.");
|
||||
options.addOption(CliConstants.N_WORKERS, true,
|
||||
"Number of worker tasks of the job, by default it's 1");
|
||||
options.addOption(CliConstants.N_PS, true,
|
||||
"Number of PS tasks of the job, by default it's 0");
|
||||
options.addOption(CliConstants.WORKER_RES, true,
|
||||
"Resource of each worker, for example "
|
||||
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2");
|
||||
options.addOption(CliConstants.PS_RES, true,
|
||||
"Resource of each PS, for example "
|
||||
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2");
|
||||
+ "could be local or other FS directory. " +
|
||||
"This will be used to serve.");
|
||||
options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag");
|
||||
options.addOption(CliConstants.QUEUE, true,
|
||||
"Name of queue to run the job, by default it uses default queue");
|
||||
options.addOption(CliConstants.TENSORBOARD, false,
|
||||
"Should we run TensorBoard"
|
||||
+ " for this job? By default it's disabled");
|
||||
options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
|
||||
"Specify resources of Tensorboard, by default it is "
|
||||
+ CliConstants.TENSORBOARD_DEFAULT_RESOURCES);
|
||||
options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
|
||||
"Specify Tensorboard docker image. when this is not "
|
||||
+ "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
|
||||
+ " as default.");
|
||||
options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
|
||||
"Commandline of worker, arguments will be "
|
||||
+ "directly used to launch the worker");
|
||||
options.addOption(CliConstants.PS_LAUNCH_CMD, true,
|
||||
"Commandline of worker, arguments will be "
|
||||
+ "directly used to launch the PS");
|
||||
|
||||
addWorkerOptions(options);
|
||||
addPSOptions(options);
|
||||
addTensorboardOptions(options);
|
||||
|
||||
options.addOption(CliConstants.ENV, true,
|
||||
"Common environment variable of worker/ps");
|
||||
options.addOption(CliConstants.VERBOSE, false,
|
||||
"Print verbose log for troubleshooting");
|
||||
options.addOption(CliConstants.WAIT_JOB_FINISH, false,
|
||||
"Specified when user want to wait the job finish");
|
||||
options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
|
||||
"Specify docker image for PS, when this is not specified, PS uses --"
|
||||
+ CliConstants.DOCKER_IMAGE + " as default.");
|
||||
options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
|
||||
"Specify docker image for WORKER, when this is not specified, WORKER "
|
||||
+ "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
|
||||
options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
|
||||
+ "web UI shows link to given role instance and port. When "
|
||||
+ "--tensorboard is specified, quicklink to tensorboard instance will "
|
||||
|
@ -172,63 +178,97 @@ public class RunJobCli extends AbstractCli {
|
|||
return options;
|
||||
}
|
||||
|
||||
private void replacePatternsInParameters() throws IOException {
|
||||
if (parameters.getPSLaunchCmd() != null && !parameters.getPSLaunchCmd()
|
||||
.isEmpty()) {
|
||||
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
|
||||
parameters.getPSLaunchCmd(), parameters,
|
||||
clientContext.getRemoteDirectoryManager());
|
||||
parameters.setPSLaunchCmd(afterReplace);
|
||||
}
|
||||
private void addWorkerOptions(Options options) {
|
||||
options.addOption(CliConstants.N_WORKERS, true,
|
||||
"Number of worker tasks of the job, by default it's 1." +
|
||||
CAN_BE_USED_WITH_TF_PYTORCH);
|
||||
options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
|
||||
"Specify docker image for WORKER, when this is not specified, WORKER "
|
||||
+ "uses --" + CliConstants.DOCKER_IMAGE + " as default." +
|
||||
CAN_BE_USED_WITH_TF_PYTORCH);
|
||||
options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
|
||||
"Commandline of worker, arguments will be "
|
||||
+ "directly used to launch the worker" +
|
||||
CAN_BE_USED_WITH_TF_PYTORCH);
|
||||
options.addOption(CliConstants.WORKER_RES, true,
|
||||
"Resource of each worker, for example "
|
||||
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
|
||||
CAN_BE_USED_WITH_TF_PYTORCH);
|
||||
}
|
||||
|
||||
if (parameters.getWorkerLaunchCmd() != null && !parameters
|
||||
.getWorkerLaunchCmd().isEmpty()) {
|
||||
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
|
||||
parameters.getWorkerLaunchCmd(), parameters,
|
||||
clientContext.getRemoteDirectoryManager());
|
||||
parameters.setWorkerLaunchCmd(afterReplace);
|
||||
}
|
||||
private void addPSOptions(Options options) {
|
||||
options.addOption(CliConstants.N_PS, true,
|
||||
"Number of PS tasks of the job, by default it's 0. " +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
|
||||
"Specify docker image for PS, when this is not specified, PS uses --"
|
||||
+ CliConstants.DOCKER_IMAGE + " as default." +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
options.addOption(CliConstants.PS_LAUNCH_CMD, true,
|
||||
"Commandline of worker, arguments will be "
|
||||
+ "directly used to launch the PS" +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
options.addOption(CliConstants.PS_RES, true,
|
||||
"Resource of each PS, for example "
|
||||
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
}
|
||||
|
||||
private void addTensorboardOptions(Options options) {
|
||||
options.addOption(CliConstants.TENSORBOARD, false,
|
||||
"Should we run TensorBoard"
|
||||
+ " for this job? By default it's disabled." +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
|
||||
"Specify resources of Tensorboard, by default it is "
|
||||
+ CliConstants.TENSORBOARD_DEFAULT_RESOURCES + "." +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
|
||||
"Specify Tensorboard docker image. when this is not "
|
||||
+ "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
|
||||
+ " as default." +
|
||||
CAN_BE_USED_WITH_TF_ONLY);
|
||||
}
|
||||
|
||||
private void parseCommandLineAndGetRunJobParameters(String[] args)
|
||||
throws ParseException, IOException, YarnException {
|
||||
try {
|
||||
// Do parsing
|
||||
GnuParser parser = new GnuParser();
|
||||
CommandLine cli = parser.parse(options, args);
|
||||
ParametersHolder parametersHolder = createParametersHolder(cli);
|
||||
parameters.updateParameters(parametersHolder, clientContext);
|
||||
parametersHolder = createParametersHolder(cli);
|
||||
parametersHolder.updateParameters(clientContext);
|
||||
} catch (ParseException e) {
|
||||
LOG.error("Exception in parse: {}", e.getMessage());
|
||||
printUsages();
|
||||
throw e;
|
||||
}
|
||||
|
||||
// Set default job dir / saved model dir, etc.
|
||||
setDefaultDirs();
|
||||
|
||||
// replace patterns
|
||||
replacePatternsInParameters();
|
||||
}
|
||||
|
||||
private ParametersHolder createParametersHolder(CommandLine cli) {
|
||||
private ParametersHolder createParametersHolder(CommandLine cli)
|
||||
throws ParseException, YarnException {
|
||||
String yamlConfigFile =
|
||||
cli.getOptionValue(CliConstants.YAML_CONFIG);
|
||||
if (yamlConfigFile != null) {
|
||||
YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile);
|
||||
if (yamlConfig == null) {
|
||||
throw new YamlParseException(String.format(
|
||||
YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
|
||||
} else if (yamlConfig.getConfigs() == null) {
|
||||
throw new YamlParseException(String.format(YAML_PARSE_FAILED +
|
||||
", config section should be defined, but it cannot be found in " +
|
||||
"YAML file '%s'!", yamlConfigFile));
|
||||
}
|
||||
checkYamlConfig(yamlConfigFile, yamlConfig);
|
||||
LOG.info("Using YAML configuration!");
|
||||
return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig);
|
||||
return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig,
|
||||
Command.RUN_JOB);
|
||||
} else {
|
||||
LOG.info("Using CLI configuration!");
|
||||
return ParametersHolder.createWithCmdLine(cli);
|
||||
return ParametersHolder.createWithCmdLine(cli, Command.RUN_JOB);
|
||||
}
|
||||
}
|
||||
|
||||
private void checkYamlConfig(String yamlConfigFile,
|
||||
YamlConfigFile yamlConfig) {
|
||||
if (yamlConfig == null) {
|
||||
throw new YamlParseException(String.format(
|
||||
YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
|
||||
} else if (yamlConfig.getConfigs() == null) {
|
||||
throw new YamlParseException(String.format(YAML_PARSE_FAILED +
|
||||
", config section should be defined, but it cannot be found in " +
|
||||
"YAML file '%s'!", yamlConfigFile));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -256,34 +296,9 @@ public class RunJobCli extends AbstractCli {
|
|||
e);
|
||||
}
|
||||
|
||||
private void setDefaultDirs() throws IOException {
|
||||
// Create directories if needed
|
||||
String jobDir = parameters.getCheckpointPath();
|
||||
if (null == jobDir) {
|
||||
if (parameters.getNumWorkers() > 0) {
|
||||
jobDir = clientContext.getRemoteDirectoryManager().getJobCheckpointDir(
|
||||
parameters.getName(), true).toString();
|
||||
} else {
|
||||
// when #workers == 0, it means we only launch TB. In that case,
|
||||
// point job dir to root dir so all job's metrics will be shown.
|
||||
jobDir = clientContext.getRemoteDirectoryManager().getUserRootFolder()
|
||||
.toString();
|
||||
}
|
||||
parameters.setCheckpointPath(jobDir);
|
||||
}
|
||||
|
||||
if (parameters.getNumWorkers() > 0) {
|
||||
// Only do this when #worker > 0
|
||||
String savedModelDir = parameters.getSavedModelPath();
|
||||
if (null == savedModelDir) {
|
||||
savedModelDir = jobDir;
|
||||
parameters.setSavedModelPath(savedModelDir);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void storeJobInformation(String jobName, ApplicationId applicationId,
|
||||
String[] args) throws IOException {
|
||||
private void storeJobInformation(RunJobParameters parameters,
|
||||
ApplicationId applicationId, String[] args) throws IOException {
|
||||
String jobName = parameters.getName();
|
||||
Map<String, String> jobInfo = new HashMap<>();
|
||||
jobInfo.put(StorageKeyConstants.JOB_NAME, jobName);
|
||||
jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString());
|
||||
|
@ -316,8 +331,10 @@ public class RunJobCli extends AbstractCli {
|
|||
}
|
||||
|
||||
parseCommandLineAndGetRunJobParameters(args);
|
||||
ApplicationId applicationId = this.jobSubmitter.submitJob(parameters);
|
||||
storeJobInformation(parameters.getName(), applicationId, args);
|
||||
ApplicationId applicationId = jobSubmitter.submitJob(parametersHolder);
|
||||
RunJobParameters parameters =
|
||||
(RunJobParameters) parametersHolder.getParameters();
|
||||
storeJobInformation(parameters, applicationId, args);
|
||||
if (parameters.isWaitJobFinish()) {
|
||||
this.jobMonitor.waitTrainingFinal(parameters.getName());
|
||||
}
|
||||
|
@ -332,6 +349,6 @@ public class RunJobCli extends AbstractCli {
|
|||
|
||||
@VisibleForTesting
|
||||
public RunJobParameters getRunJobParameters() {
|
||||
return parameters;
|
||||
return (RunJobParameters) parametersHolder.getParameters();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* This package contains classes that are related to the run job command.
|
||||
*/
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. See accompanying LICENSE file.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.common.api;
|
||||
|
||||
/**
|
||||
* Enum to represent a PyTorch Role.
|
||||
*/
|
||||
public enum PyTorchRole implements Role {
|
||||
PRIMARY_WORKER("master"),
|
||||
WORKER("worker");
|
||||
|
||||
private String compName;
|
||||
|
||||
PyTorchRole(String compName) {
|
||||
this.compName = compName;
|
||||
}
|
||||
|
||||
public String getComponentName() {
|
||||
return compName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.common.api;
|
||||
|
||||
/**
 * Interface for a Role.
 * A role exposes two strings: a component name and a name.
 */
public interface Role {
  /**
   * @return the component name of this role
   */
  String getComponentName();

  /**
   * @return the name of this role
   */
  String getName();
}
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.common.api;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
 * Represents the type of Runtime.
 * Each constant carries the lower-case textual value that identifies it.
 */
public enum Runtime {
  TONY(Constants.TONY), YARN_SERVICE(Constants.YARN_SERVICE);

  // Enum state is shared by every user of the constant, so keep it immutable.
  private final String value;

  Runtime(String value) {
    this.value = value;
  }

  /**
   * @return the textual value of this runtime
   */
  public String getValue() {
    return value;
  }

  /**
   * Finds the runtime whose textual value matches the given string,
   * ignoring case.
   *
   * @param value textual value to look up, may be null
   * @return the matching runtime, or null when nothing matches
   */
  public static Runtime parseByValue(String value) {
    for (Runtime rt : Runtime.values()) {
      if (rt.value.equalsIgnoreCase(value)) {
        return rt;
      }
    }
    return null;
  }

  /**
   * @return the textual values of all runtimes in declaration order,
   *         joined with a comma
   */
  public static String getValues() {
    // Stream the array directly; copying it into a list first is not needed.
    List<String> values = java.util.Arrays.stream(Runtime.values())
        .map(rt -> rt.value).collect(Collectors.toList());
    return String.join(",", values);
  }

  /**
   * Holds the textual values so that they can be referenced from the
   * enum constant declarations above.
   */
  public static class Constants {
    public static final String TONY = "tony";
    public static final String YARN_SERVICE = "yarnservice";
  }
}
|
|
@ -14,7 +14,10 @@
|
|||
|
||||
package org.apache.hadoop.yarn.submarine.common.api;
|
||||
|
||||
public enum TaskType {
|
||||
/**
|
||||
* Enum to represent a TensorFlow Role.
|
||||
*/
|
||||
public enum TensorFlowRole implements Role {
|
||||
PRIMARY_WORKER("master"),
|
||||
WORKER("worker"),
|
||||
PS("ps"),
|
||||
|
@ -22,11 +25,17 @@ public enum TaskType {
|
|||
|
||||
private String compName;
|
||||
|
||||
TaskType(String compName) {
|
||||
TensorFlowRole(String compName) {
|
||||
this.compName = compName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getComponentName() {
|
||||
return compName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name();
|
||||
}
|
||||
}
|
|
@ -16,21 +16,21 @@ package org.apache.hadoop.yarn.submarine.runtimes.common;
|
|||
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Submit job to cluster master
|
||||
* Submit job to cluster master.
|
||||
*/
|
||||
public interface JobSubmitter {
|
||||
/**
|
||||
* Submit job to cluster
|
||||
* Submit a job to cluster.
|
||||
* @param parameters run job parameters
|
||||
* @return applicatioId when successfully submitted
|
||||
* @return applicationId when successfully submitted
|
||||
* @throws YarnException for issues while contacting YARN daemons
|
||||
* @throws IOException for other issues.
|
||||
*/
|
||||
ApplicationId submitJob(RunJobParameters parameters)
|
||||
ApplicationId submitJob(ParametersHolder parameters)
|
||||
throws IOException, YarnException;
|
||||
}
|
||||
|
|
|
@ -40,6 +40,10 @@ More details, please refer to
|
|||
|
||||
```$xslt
|
||||
usage: job run
|
||||
|
||||
-framework <arg> Framework to use.
|
||||
Valid values are: tensorflow, pytorch.
|
||||
The default framework is Tensorflow.
|
||||
-checkpoint_path <arg> Training output directory of the job, could
|
||||
be local or other FS directory. This
|
||||
typically includes checkpoint files and
|
||||
|
@ -130,6 +134,7 @@ For submarine internal configuration, please create a `submarine.xml` which shou
|
|||
#### Commandline
|
||||
```
|
||||
yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \
|
||||
--framework tensorflow \
|
||||
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
|
||||
--env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \
|
||||
--docker_image <your-docker-image> \
|
||||
|
@ -163,6 +168,7 @@ See below screenshot:
|
|||
```
|
||||
yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
|
||||
--name tf-job-001 --docker_image <your-docker-image> \
|
||||
--framework tensorflow \
|
||||
--input_path hdfs://default/dataset/cifar-10-data \
|
||||
--checkpoint_path hdfs://default/tmp/cifar-10-jobdir \
|
||||
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
|
||||
|
@ -208,6 +214,7 @@ After that, you can run ```tensorboard --logdir=<checkpoint-path>``` to view Ten
|
|||
yarn app -destroy tensorboard-service; \
|
||||
yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \
|
||||
job run --name tensorboard-service --verbose --docker_image <your-docker-image> \
|
||||
--framework tensorflow \
|
||||
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
|
||||
--env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \
|
||||
--num_workers 0 --tensorboard
|
||||
|
|
|
@ -1,226 +0,0 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
|
||||
import org.apache.hadoop.yarn.util.resource.Resources;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class TestRunJobCliParsing {
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPrintHelp() {
|
||||
MockClientContext mockClientContext = new MockClientContext();
|
||||
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
|
||||
JobMonitor mockJobMonitor = mock(JobMonitor.class);
|
||||
RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
|
||||
mockJobMonitor);
|
||||
runJobCli.printUsages();
|
||||
}
|
||||
|
||||
static MockClientContext getMockClientContext()
|
||||
throws IOException, YarnException {
|
||||
MockClientContext mockClientContext = new MockClientContext();
|
||||
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
|
||||
when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
|
||||
ApplicationId.newInstance(1234L, 1));
|
||||
JobMonitor mockJobMonitor = mock(JobMonitor.class);
|
||||
SubmarineStorage storage = mock(SubmarineStorage.class);
|
||||
RuntimeFactory rtFactory = mock(RuntimeFactory.class);
|
||||
|
||||
when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
|
||||
when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
|
||||
when(rtFactory.getSubmarineStorage()).thenReturn(storage);
|
||||
|
||||
mockClientContext.setRuntimeFactory(rtFactory);
|
||||
return mockClientContext;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicRunJobForDistributedTraining() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
|
||||
"--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
|
||||
"--principal", "user/_HOST@domain.com", "--distribute_keytab",
|
||||
"--verbose" });
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
|
||||
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
|
||||
assertEquals(jobRunParameters.getNumPS(), 2);
|
||||
assertEquals(jobRunParameters.getPSLaunchCmd(), "python run-ps.py");
|
||||
assertEquals(Resources.createResource(4096, 4),
|
||||
jobRunParameters.getPsResource());
|
||||
assertEquals(jobRunParameters.getWorkerLaunchCmd(),
|
||||
"python run-job.py");
|
||||
assertEquals(Resources.createResource(2048, 2),
|
||||
jobRunParameters.getWorkerResource());
|
||||
assertEquals(jobRunParameters.getDockerImageName(),
|
||||
"tf-docker:1.1.0");
|
||||
assertEquals(jobRunParameters.getKeytab(),
|
||||
"/keytab/path");
|
||||
assertEquals(jobRunParameters.getPrincipal(),
|
||||
"user/_HOST@domain.com");
|
||||
Assert.assertTrue(jobRunParameters.isDistributeKeytab());
|
||||
Assert.assertTrue(SubmarineLogs.isVerbose());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicRunJobForSingleNodeTraining() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish" });
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
|
||||
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
|
||||
assertEquals(jobRunParameters.getNumWorkers(), 1);
|
||||
assertEquals(jobRunParameters.getWorkerLaunchCmd(),
|
||||
"python run-job.py");
|
||||
assertEquals(Resources.createResource(4096, 2),
|
||||
jobRunParameters.getWorkerResource());
|
||||
Assert.assertTrue(SubmarineLogs.isVerbose());
|
||||
Assert.assertTrue(jobRunParameters.isWaitJobFinish());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoInputPathOptionSpecified() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\" is absent";
|
||||
String actualMessage = "";
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish"});
|
||||
} catch (ParseException e) {
|
||||
actualMessage = e.getMessage();
|
||||
e.printStackTrace();
|
||||
}
|
||||
assertEquals(expectedErrorMessage, actualMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* when only run tensorboard, input_path is not needed
|
||||
* */
|
||||
@Test
|
||||
public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
boolean success = true;
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--num_workers", "0", "--tensorboard", "--verbose",
|
||||
"--tensorboard_resources", "memory=2G,vcores=2",
|
||||
"--tensorboard_docker_image", "tb_docker_image:001"});
|
||||
} catch (ParseException e) {
|
||||
success = false;
|
||||
}
|
||||
Assert.assertTrue(success);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJobWithoutName() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
String expectedErrorMessage =
|
||||
"--" + CliConstants.NAME + " is absent";
|
||||
String actualMessage = "";
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--docker_image", "tf-docker:1.1.0",
|
||||
"--num_workers", "0", "--tensorboard", "--verbose",
|
||||
"--tensorboard_resources", "memory=2G,vcores=2",
|
||||
"--tensorboard_docker_image", "tb_docker_image:001"});
|
||||
} catch (ParseException e) {
|
||||
actualMessage = e.getMessage();
|
||||
e.printStackTrace();
|
||||
}
|
||||
assertEquals(expectedErrorMessage, actualMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLaunchCommandPatternReplace() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
|
||||
"python run-job.py --input=%input_path% " +
|
||||
"--model_dir=%checkpoint_path% " +
|
||||
"--export_dir=%saved_model_path%/savedmodel",
|
||||
"--worker_resources", "memory=2048,vcores=2", "--ps_resources",
|
||||
"memory=4096,vcores=4", "--tensorboard", "true", "--ps_launch_cmd",
|
||||
"python run-ps.py --input=%input_path% " +
|
||||
"--model_dir=%checkpoint_path%/model",
|
||||
"--verbose" });
|
||||
|
||||
assertEquals(
|
||||
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
|
||||
+ "--export_dir=hdfs://output/savedmodel",
|
||||
runJobCli.getRunJobParameters().getWorkerLaunchCmd());
|
||||
assertEquals(
|
||||
"python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model",
|
||||
runJobCli.getRunJobParameters().getPSLaunchCmd());
|
||||
}
|
||||
}
|
|
@ -17,7 +17,7 @@
|
|||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
|
||||
import org.yaml.snakeyaml.Yaml;
|
||||
import org.yaml.snakeyaml.constructor.Constructor;
|
||||
|
@ -33,13 +33,13 @@ public final class YamlConfigTestUtils {
|
|||
|
||||
private YamlConfigTestUtils() {}
|
||||
|
||||
static void deleteFile(File file) {
|
||||
public static void deleteFile(File file) {
|
||||
if (file != null) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
static YamlConfigFile readYamlConfigFile(String filename) {
|
||||
public static YamlConfigFile readYamlConfigFile(String filename) {
|
||||
Constructor constructor = new Constructor(YamlConfigFile.class);
|
||||
constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
|
||||
Yaml yaml = new Yaml(constructor);
|
||||
|
@ -49,7 +49,8 @@ public final class YamlConfigTestUtils {
|
|||
return yaml.loadAs(inputStream, YamlConfigFile.class);
|
||||
}
|
||||
|
||||
static File createTempFileWithContents(String filename) throws IOException {
|
||||
public static File createTempFileWithContents(String filename)
|
||||
throws IOException {
|
||||
InputStream inputStream = YamlConfigTestUtils.class
|
||||
.getClassLoader()
|
||||
.getResourceAsStream(filename);
|
||||
|
@ -58,7 +59,7 @@ public final class YamlConfigTestUtils {
|
|||
return targetFile;
|
||||
}
|
||||
|
||||
static File createEmptyTempFile() throws IOException {
|
||||
public static File createEmptyTempFile() throws IOException {
|
||||
return File.createTempFile("test", ".yaml");
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
||||
|
||||
import org.apache.commons.cli.MissingArgumentException;
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import java.io.IOException;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
|
||||
* This class contains some test methods to test common functionality
|
||||
* (including TF / PyTorch) of the run job Submarine command.
|
||||
*/
|
||||
public class TestRunJobCliParsingCommon {
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedException = ExpectedException.none();
|
||||
|
||||
public static MockClientContext getMockClientContext()
|
||||
throws IOException, YarnException {
|
||||
MockClientContext mockClientContext = new MockClientContext();
|
||||
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
|
||||
when(mockJobSubmitter.submitJob(any(ParametersHolder.class)))
|
||||
.thenReturn(ApplicationId.newInstance(1235L, 1));
|
||||
|
||||
JobMonitor mockJobMonitor = mock(JobMonitor.class);
|
||||
SubmarineStorage storage = mock(SubmarineStorage.class);
|
||||
RuntimeFactory rtFactory = mock(RuntimeFactory.class);
|
||||
|
||||
when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
|
||||
when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
|
||||
when(rtFactory.getSubmarineStorage()).thenReturn(storage);
|
||||
|
||||
mockClientContext.setRuntimeFactory(rtFactory);
|
||||
return mockClientContext;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAbsentFrameworkFallsBackToTensorFlow() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish"});
|
||||
RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
|
||||
assertTrue("Default Framework should be TensorFlow!",
|
||||
runJobParameters instanceof TensorFlowRunJobParameters);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyFrameworkOption() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(MissingArgumentException.class);
|
||||
expectedException.expectMessage("Missing argument for option: framework");
|
||||
|
||||
runJobCli.run(
|
||||
new String[]{"--framework", "--name", "my-job",
|
||||
"--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidFrameworkOption() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("Failed to parse Framework type");
|
||||
|
||||
runJobCli.run(
|
||||
new String[]{"--framework", "bla", "--name", "my-job",
|
||||
"--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish"});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
||||
|
||||
import org.apache.hadoop.yarn.api.records.ResourceInformation;
|
||||
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* This class contains some test methods to test common YAML parsing
|
||||
* functionality (including TF / PyTorch) of the run job Submarine command.
|
||||
*/
|
||||
public class TestRunJobCliParsingCommonYaml {
|
||||
private static final String DIR_NAME = "runjob-common-yaml";
|
||||
private static final String TF_DIR = "runjob-pytorch-yaml";
|
||||
private File yamlConfig;
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() {
|
||||
YamlConfigTestUtils.deleteFile(yamlConfig);
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void configureResourceTypes() {
|
||||
List<ResourceTypeInfo> resTypes = new ArrayList<>(
|
||||
ResourceUtils.getResourcesTypeInfo());
|
||||
resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
|
||||
ResourceUtils.reinitializeResources(resTypes);
|
||||
}
|
||||
|
||||
@Rule
|
||||
public ExpectedException exception = ExpectedException.none();
|
||||
|
||||
@Test
|
||||
public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
TF_DIR + "/valid-config.yaml");
|
||||
String[] args = new String[] {"--name", "my-job",
|
||||
"--docker_image", "tf-docker:1.1.0",
|
||||
"-f", yamlConfig.getAbsolutePath() };
|
||||
|
||||
exception.expect(YarnException.class);
|
||||
exception.expectMessage("defined both with YAML config and with " +
|
||||
"CLI argument");
|
||||
|
||||
runJobCli.run(args);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
|
||||
throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
TF_DIR + "/valid-config.yaml");
|
||||
String[] args = new String[] {"--name", "my-job",
|
||||
"--quicklink", "AAA=http://master-0:8321",
|
||||
"--quicklink", "BBB=http://worker-0:1234",
|
||||
"-f", yamlConfig.getAbsolutePath()};
|
||||
|
||||
exception.expect(YarnException.class);
|
||||
exception.expectMessage("defined both with YAML config and with " +
|
||||
"CLI argument");
|
||||
|
||||
runJobCli.run(args);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFalseValuesForBooleanFields() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/test-false-values.yaml");
|
||||
runJobCli.run(
|
||||
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertFalse(jobRunParameters.isDistributeKeytab());
|
||||
assertFalse(jobRunParameters.isWaitJobFinish());
|
||||
assertFalse(tensorFlowParams.isTensorboardEnabled());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrongIndentation() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/wrong-indentation.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("Failed to parse YAML config, details:");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrongFilename() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
runJobCli.run(
|
||||
new String[]{"-f", "not-existing", "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyFile() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotExistingFile() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("file does not exist");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", "blabla", "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrongPropertyName() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/wrong-property-name.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("Failed to parse YAML config, details:");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingConfigsSection() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/missing-configs.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("config section should be defined, " +
|
||||
"but it cannot be found");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingSectionsShouldParsed() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/some-sections-missing.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testAbsentFramework() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/missing-framework.yaml");
|
||||
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyFramework() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/empty-framework.yaml");
|
||||
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidFramework() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/invalid-framework.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("framework should is defined, " +
|
||||
"but it has an invalid value");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* This class contains some test methods to test common CLI parsing
|
||||
* functionality (including TF / PyTorch) of the run job Submarine command.
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestRunJobCliParsingParameterized {
|
||||
|
||||
private final Framework framework;
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedException = ExpectedException.none();
|
||||
|
||||
@Parameterized.Parameters
|
||||
public static Collection<Object[]> data() {
|
||||
Collection<Object[]> params = new ArrayList<>();
|
||||
params.add(new Object[]{Framework.TENSORFLOW });
|
||||
params.add(new Object[]{Framework.PYTORCH });
|
||||
return params;
|
||||
}
|
||||
|
||||
public TestRunJobCliParsingParameterized(Framework framework) {
|
||||
this.framework = framework;
|
||||
}
|
||||
|
||||
private String getFrameworkName() {
|
||||
return framework.getValue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPrintHelp() {
|
||||
MockClientContext mockClientContext = new MockClientContext();
|
||||
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
|
||||
JobMonitor mockJobMonitor = mock(JobMonitor.class);
|
||||
RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
|
||||
mockJobMonitor);
|
||||
runJobCli.printUsages();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoInputPathOptionSpecified() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\"" +
|
||||
" is absent";
|
||||
String actualMessage = "";
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--framework", getFrameworkName(),
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--verbose",
|
||||
"--wait_job_finish"});
|
||||
} catch (ParseException e) {
|
||||
actualMessage = e.getMessage();
|
||||
e.printStackTrace();
|
||||
}
|
||||
assertEquals(expectedErrorMessage, actualMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJobWithoutName() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
String expectedErrorMessage =
|
||||
"--" + CliConstants.NAME + " is absent";
|
||||
String actualMessage = "";
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--framework", getFrameworkName(),
|
||||
"--docker_image", "tf-docker:1.1.0",
|
||||
"--num_workers", "0", "--verbose"});
|
||||
} catch (ParseException e) {
|
||||
actualMessage = e.getMessage();
|
||||
e.printStackTrace();
|
||||
}
|
||||
assertEquals(expectedErrorMessage, actualMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLaunchCommandPatternReplace() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
List<String> parameters = Lists.newArrayList("--framework",
|
||||
getFrameworkName(),
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd", "python run-job.py --input=%input_path% " +
|
||||
"--model_dir=%checkpoint_path% " +
|
||||
"--export_dir=%saved_model_path%/savedmodel",
|
||||
"--worker_resources", "memory=2048,vcores=2");
|
||||
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
parameters.addAll(Lists.newArrayList(
|
||||
"--ps_resources", "memory=4096,vcores=4",
|
||||
"--ps_launch_cmd", "python run-ps.py --input=%input_path% " +
|
||||
"--model_dir=%checkpoint_path%/model",
|
||||
"--verbose"));
|
||||
}
|
||||
|
||||
runJobCli.run(parameters.toArray(new String[0]));
|
||||
|
||||
RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);
|
||||
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) runJobParameters;
|
||||
assertEquals(
|
||||
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
|
||||
+ "--export_dir=hdfs://output/savedmodel",
|
||||
tensorFlowParams.getWorkerLaunchCmd());
|
||||
assertEquals(
|
||||
"python run-ps.py --input=hdfs://input " +
|
||||
"--model_dir=hdfs://output/model",
|
||||
tensorFlowParams.getPSLaunchCmd());
|
||||
} else if (framework == Framework.PYTORCH) {
|
||||
PyTorchRunJobParameters pyTorchParameters =
|
||||
(PyTorchRunJobParameters) runJobParameters;
|
||||
assertEquals(
|
||||
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
|
||||
+ "--export_dir=hdfs://output/savedmodel",
|
||||
pyTorchParameters.getWorkerLaunchCmd());
|
||||
}
|
||||
}
|
||||
|
||||
private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
|
||||
RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
runJobParameters instanceof TensorFlowRunJobParameters);
|
||||
} else if (framework == Framework.PYTORCH) {
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
PyTorchRunJobParameters.class,
|
||||
runJobParameters instanceof PyTorchRunJobParameters);
|
||||
}
|
||||
return runJobParameters;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,209 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
|
||||
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.util.resource.Resources;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Test class that verifies the correctness of PyTorch
|
||||
* CLI configuration parsing.
|
||||
*/
|
||||
public class TestRunJobCliParsingPyTorch {
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedException = ExpectedException.none();
|
||||
|
||||
@Test
|
||||
public void testBasicRunJobForSingleNodeTraining() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--verbose",
|
||||
"--wait_job_finish" });
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
assertTrue(RunJobParameters.class +
|
||||
" must be an instance of " +
|
||||
PyTorchRunJobParameters.class,
|
||||
jobRunParameters instanceof PyTorchRunJobParameters);
|
||||
PyTorchRunJobParameters pyTorchParams =
|
||||
(PyTorchRunJobParameters) jobRunParameters;
|
||||
|
||||
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
|
||||
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
|
||||
assertEquals(pyTorchParams.getNumWorkers(), 1);
|
||||
assertEquals(pyTorchParams.getWorkerLaunchCmd(),
|
||||
"python run-job.py");
|
||||
assertEquals(Resources.createResource(4096, 2),
|
||||
pyTorchParams.getWorkerResource());
|
||||
assertTrue(SubmarineLogs.isVerbose());
|
||||
assertTrue(jobRunParameters.isWaitJobFinish());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNumPSCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path","hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--num_ps", "2" });
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPSResourcesCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--ps_resources", "memory=2048M,vcores=2" });
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPSDockerImageCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--ps_docker_image", "psDockerImage" });
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPSLaunchCommandCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--ps_launch_cmd", "psLaunchCommand" });
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorboardCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--tensorboard" });
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorboardResourcesCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--tensorboard_resources", "memory=2048M,vcores=2" });
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTensorboardDockerImageCannotBeDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
expectedException.expect(ParseException.class);
|
||||
expectedException.expectMessage("cannot be defined for PyTorch jobs");
|
||||
runJobCli.run(
|
||||
new String[] {"--framework", "pytorch",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3",
|
||||
"--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--tensorboard_docker_image", "TBDockerImage" });
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,225 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.hadoop.yarn.api.records.ResourceInformation;
|
||||
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
|
||||
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
|
||||
/**
|
||||
* Test class that verifies the correctness of PyTorch
|
||||
* YAML configuration parsing.
|
||||
*/
|
||||
public class TestRunJobCliParsingPyTorchYaml {
|
||||
private static final String OVERRIDDEN_PREFIX = "overridden_";
|
||||
private static final String DIR_NAME = "runjob-pytorch-yaml";
|
||||
private File yamlConfig;
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() {
|
||||
YamlConfigTestUtils.deleteFile(yamlConfig);
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void configureResourceTypes() {
|
||||
List<ResourceTypeInfo> resTypes = new ArrayList<>(
|
||||
ResourceUtils.getResourcesTypeInfo());
|
||||
resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
|
||||
ResourceUtils.reinitializeResources(resTypes);
|
||||
}
|
||||
|
||||
@Rule
|
||||
public ExpectedException exception = ExpectedException.none();
|
||||
|
||||
private void verifyBasicConfigValues(RunJobParameters jobRunParameters) {
|
||||
verifyBasicConfigValues(jobRunParameters,
|
||||
ImmutableList.of("env1=env1Value", "env2=env2Value"));
|
||||
}
|
||||
|
||||
private void verifyBasicConfigValues(RunJobParameters jobRunParameters,
|
||||
List<String> expectedEnvs) {
|
||||
assertEquals("testInputPath", jobRunParameters.getInputPath());
|
||||
assertEquals("testCheckpointPath", jobRunParameters.getCheckpointPath());
|
||||
assertEquals("testDockerImage", jobRunParameters.getDockerImageName());
|
||||
|
||||
assertNotNull(jobRunParameters.getLocalizations());
|
||||
assertEquals(2, jobRunParameters.getLocalizations().size());
|
||||
|
||||
assertNotNull(jobRunParameters.getQuicklinks());
|
||||
assertEquals(2, jobRunParameters.getQuicklinks().size());
|
||||
|
||||
assertTrue(SubmarineLogs.isVerbose());
|
||||
assertTrue(jobRunParameters.isWaitJobFinish());
|
||||
|
||||
for (String env : expectedEnvs) {
|
||||
assertTrue(String.format(
|
||||
"%s should be in env list of jobRunParameters!", env),
|
||||
jobRunParameters.getEnvars().contains(env));
|
||||
}
|
||||
}
|
||||
|
||||
private void verifyWorkerValues(RunJobParameters jobRunParameters,
|
||||
String prefix) {
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
PyTorchRunJobParameters.class,
|
||||
jobRunParameters instanceof PyTorchRunJobParameters);
|
||||
PyTorchRunJobParameters tensorFlowParams =
|
||||
(PyTorchRunJobParameters) jobRunParameters;
|
||||
|
||||
assertEquals(3, tensorFlowParams.getNumWorkers());
|
||||
assertEquals(prefix + "testLaunchCmdWorker",
|
||||
tensorFlowParams.getWorkerLaunchCmd());
|
||||
assertEquals(prefix + "testDockerImageWorker",
|
||||
tensorFlowParams.getWorkerDockerImage());
|
||||
assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
|
||||
ImmutableMap.<String, String> builder()
|
||||
.put(ResourceInformation.GPU_URI, "2").build()),
|
||||
tensorFlowParams.getWorkerResource());
|
||||
}
|
||||
|
||||
private void verifySecurityValues(RunJobParameters jobRunParameters) {
|
||||
assertEquals("keytabPath", jobRunParameters.getKeytab());
|
||||
assertEquals("testPrincipal", jobRunParameters.getPrincipal());
|
||||
assertTrue(jobRunParameters.isDistributeKeytab());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testValidYamlParsing() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/valid-config.yaml");
|
||||
runJobCli.run(
|
||||
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
verifyBasicConfigValues(jobRunParameters);
|
||||
verifyWorkerValues(jobRunParameters, "");
|
||||
verifySecurityValues(jobRunParameters);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRoleOverrides() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/valid-config-with-overrides.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
verifyBasicConfigValues(jobRunParameters);
|
||||
verifyWorkerValues(jobRunParameters, OVERRIDDEN_PREFIX);
|
||||
verifySecurityValues(jobRunParameters);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingPrincipalUnderSecuritySection() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/security-principal-is-missing.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
verifyBasicConfigValues(jobRunParameters);
|
||||
verifyWorkerValues(jobRunParameters, "");
|
||||
|
||||
//Verify security values
|
||||
assertEquals("keytabPath", jobRunParameters.getKeytab());
|
||||
assertNull("Principal should be null!", jobRunParameters.getPrincipal());
|
||||
assertTrue(jobRunParameters.isDistributeKeytab());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingEnvs() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/envs-are-missing.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
verifyBasicConfigValues(jobRunParameters, ImmutableList.of());
|
||||
verifyWorkerValues(jobRunParameters, "");
|
||||
verifySecurityValues(jobRunParameters);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidConfigPsSectionIsDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("PS section should not be defined " +
|
||||
"when PyTorch is the selected framework");
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/invalid-config-ps-section.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidConfigTensorboardSectionIsDefined() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("TensorBoard section should not be defined " +
|
||||
"when PyTorch is the selected framework");
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/invalid-config-tensorboard-section.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,170 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
|
||||
|
||||
import org.apache.commons.cli.ParseException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.util.resource.Resources;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Test class that verifies the correctness of TensorFlow
|
||||
* CLI configuration parsing.
|
||||
*/
|
||||
public class TestRunJobCliParsingTensorFlow {
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedException = ExpectedException.none();
|
||||
|
||||
@Test
|
||||
public void testNoInputPathOptionSpecified() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
|
||||
"\" is absent";
|
||||
String actualMessage = "";
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--framework", "tensorflow",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish"});
|
||||
} catch (ParseException e) {
|
||||
actualMessage = e.getMessage();
|
||||
e.printStackTrace();
|
||||
}
|
||||
assertEquals(expectedErrorMessage, actualMessage);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicRunJobForDistributedTraining() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[] { "--framework", "tensorflow",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--checkpoint_path", "hdfs://output",
|
||||
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
|
||||
"--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
|
||||
"--principal", "user/_HOST@domain.com", "--distribute_keytab",
|
||||
"--verbose" });
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
assertTrue(RunJobParameters.class +
|
||||
" must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
|
||||
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
|
||||
assertEquals(tensorFlowParams.getNumPS(), 2);
|
||||
assertEquals(tensorFlowParams.getPSLaunchCmd(), "python run-ps.py");
|
||||
assertEquals(Resources.createResource(4096, 4),
|
||||
tensorFlowParams.getPsResource());
|
||||
assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
|
||||
"python run-job.py");
|
||||
assertEquals(Resources.createResource(2048, 2),
|
||||
tensorFlowParams.getWorkerResource());
|
||||
assertEquals(jobRunParameters.getDockerImageName(),
|
||||
"tf-docker:1.1.0");
|
||||
assertEquals(jobRunParameters.getKeytab(),
|
||||
"/keytab/path");
|
||||
assertEquals(jobRunParameters.getPrincipal(),
|
||||
"user/_HOST@domain.com");
|
||||
assertTrue(jobRunParameters.isDistributeKeytab());
|
||||
assertTrue(SubmarineLogs.isVerbose());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicRunJobForSingleNodeTraining() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
runJobCli.run(
|
||||
new String[] { "--framework", "tensorflow",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input", "--checkpoint_path",
|
||||
"hdfs://output",
|
||||
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
|
||||
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
|
||||
"true", "--verbose", "--wait_job_finish" });
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
assertTrue(RunJobParameters.class +
|
||||
" must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
|
||||
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
|
||||
assertEquals(tensorFlowParams.getNumWorkers(), 1);
|
||||
assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
|
||||
"python run-job.py");
|
||||
assertEquals(Resources.createResource(4096, 2),
|
||||
tensorFlowParams.getWorkerResource());
|
||||
assertTrue(SubmarineLogs.isVerbose());
|
||||
assertTrue(jobRunParameters.isWaitJobFinish());
|
||||
}
|
||||
|
||||
/**
|
||||
* when only run tensorboard, input_path is not needed
|
||||
* */
|
||||
@Test
|
||||
public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
boolean success = true;
|
||||
try {
|
||||
runJobCli.run(
|
||||
new String[]{"--framework", "tensorflow",
|
||||
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
"--num_workers", "0", "--tensorboard", "--verbose",
|
||||
"--tensorboard_resources", "memory=2G,vcores=2",
|
||||
"--tensorboard_docker_image", "tb_docker_image:001"});
|
||||
} catch (ParseException e) {
|
||||
success = false;
|
||||
}
|
||||
assertTrue(success);
|
||||
}
|
||||
}
|
|
@ -15,16 +15,17 @@
|
|||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import org.apache.hadoop.yarn.api.records.ResourceInformation;
|
||||
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
|
||||
import org.junit.After;
|
||||
|
@ -39,19 +40,18 @@ import java.io.File;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.TestRunJobCliParsing.getMockClientContext;
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Test class that verifies the correctness of YAML configuration parsing.
|
||||
* Test class that verifies the correctness of TF YAML configuration parsing.
|
||||
*/
|
||||
public class TestRunJobCliParsingYaml {
|
||||
public class TestRunJobCliParsingTensorFlowYaml {
|
||||
private static final String OVERRIDDEN_PREFIX = "overridden_";
|
||||
private static final String DIR_NAME = "runjobcliparsing";
|
||||
private static final String DIR_NAME = "runjob-tensorflow-yaml";
|
||||
private File yamlConfig;
|
||||
|
||||
@Before
|
||||
|
@ -104,27 +104,39 @@ public class TestRunJobCliParsingYaml {
|
|||
|
||||
private void verifyPsValues(RunJobParameters jobRunParameters,
|
||||
String prefix) {
|
||||
assertEquals(4, jobRunParameters.getNumPS());
|
||||
assertEquals(prefix + "testLaunchCmdPs", jobRunParameters.getPSLaunchCmd());
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertEquals(4, tensorFlowParams.getNumPS());
|
||||
assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
|
||||
assertEquals(prefix + "testDockerImagePs",
|
||||
jobRunParameters.getPsDockerImage());
|
||||
tensorFlowParams.getPsDockerImage());
|
||||
assertEquals(ResourceTypesTestHelper.newResource(20500L, 34,
|
||||
ImmutableMap.<String, String> builder()
|
||||
.put(ResourceInformation.GPU_URI, "4").build()),
|
||||
jobRunParameters.getPsResource());
|
||||
tensorFlowParams.getPsResource());
|
||||
}
|
||||
|
||||
private void verifyWorkerValues(RunJobParameters jobRunParameters,
|
||||
String prefix) {
|
||||
assertEquals(3, jobRunParameters.getNumWorkers());
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertEquals(3, tensorFlowParams.getNumWorkers());
|
||||
assertEquals(prefix + "testLaunchCmdWorker",
|
||||
jobRunParameters.getWorkerLaunchCmd());
|
||||
tensorFlowParams.getWorkerLaunchCmd());
|
||||
assertEquals(prefix + "testDockerImageWorker",
|
||||
jobRunParameters.getWorkerDockerImage());
|
||||
tensorFlowParams.getWorkerDockerImage());
|
||||
assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
|
||||
ImmutableMap.<String, String> builder()
|
||||
.put(ResourceInformation.GPU_URI, "2").build()),
|
||||
jobRunParameters.getWorkerResource());
|
||||
tensorFlowParams.getWorkerResource());
|
||||
}
|
||||
|
||||
private void verifySecurityValues(RunJobParameters jobRunParameters) {
|
||||
|
@ -134,13 +146,19 @@ public class TestRunJobCliParsingYaml {
|
|||
}
|
||||
|
||||
private void verifyTensorboardValues(RunJobParameters jobRunParameters) {
|
||||
assertTrue(jobRunParameters.isTensorboardEnabled());
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertTrue(tensorFlowParams.isTensorboardEnabled());
|
||||
assertEquals("tensorboardDockerImage",
|
||||
jobRunParameters.getTensorboardDockerImage());
|
||||
tensorFlowParams.getTensorboardDockerImage());
|
||||
assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
|
||||
ImmutableMap.<String, String> builder()
|
||||
.put(ResourceInformation.GPU_URI, "3").build()),
|
||||
jobRunParameters.getTensorboardResource());
|
||||
tensorFlowParams.getTensorboardResource());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -161,44 +179,6 @@ public class TestRunJobCliParsingYaml {
|
|||
verifyTensorboardValues(jobRunParameters);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/valid-config.yaml");
|
||||
String[] args = new String[] {"--name", "my-job",
|
||||
"--docker_image", "tf-docker:1.1.0",
|
||||
"-f", yamlConfig.getAbsolutePath() };
|
||||
|
||||
exception.expect(YarnException.class);
|
||||
exception.expectMessage("defined both with YAML config and with " +
|
||||
"CLI argument");
|
||||
|
||||
runJobCli.run(args);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
|
||||
throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/valid-config.yaml");
|
||||
String[] args = new String[] {"--name", "my-job",
|
||||
"--quicklink", "AAA=http://master-0:8321",
|
||||
"--quicklink", "BBB=http://worker-0:1234",
|
||||
"-f", yamlConfig.getAbsolutePath()};
|
||||
|
||||
exception.expect(YarnException.class);
|
||||
exception.expectMessage("defined both with YAML config and with " +
|
||||
"CLI argument");
|
||||
|
||||
runJobCli.run(args);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRoleOverrides() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
@ -217,104 +197,6 @@ public class TestRunJobCliParsingYaml {
|
|||
verifyTensorboardValues(jobRunParameters);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFalseValuesForBooleanFields() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/test-false-values.yaml");
|
||||
runJobCli.run(
|
||||
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
assertFalse(jobRunParameters.isDistributeKeytab());
|
||||
assertFalse(jobRunParameters.isWaitJobFinish());
|
||||
assertFalse(jobRunParameters.isTensorboardEnabled());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrongIndentation() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/wrong-indentation.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("Failed to parse YAML config, details:");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrongFilename() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
Assert.assertFalse(SubmarineLogs.isVerbose());
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
runJobCli.run(
|
||||
new String[]{"-f", "not-existing", "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyFile() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotExistingFile() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("file does not exist");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", "blabla", "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrongPropertyName() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/wrong-property-name.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("Failed to parse YAML config, details:");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingConfigsSection() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/missing-configs.yaml");
|
||||
|
||||
exception.expect(YamlParseException.class);
|
||||
exception.expectMessage("config section should be defined, " +
|
||||
"but it cannot be found");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingSectionsShouldParsed() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
||||
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
|
||||
DIR_NAME + "/some-sections-missing.yaml");
|
||||
runJobCli.run(
|
||||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMissingPrincipalUnderSecuritySection() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
|
@ -346,18 +228,22 @@ public class TestRunJobCliParsingYaml {
|
|||
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
|
||||
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
verifyBasicConfigValues(jobRunParameters);
|
||||
verifyPsValues(jobRunParameters, "");
|
||||
verifyWorkerValues(jobRunParameters, "");
|
||||
verifySecurityValues(jobRunParameters);
|
||||
|
||||
assertTrue(jobRunParameters.isTensorboardEnabled());
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
assertTrue(tensorFlowParams.isTensorboardEnabled());
|
||||
assertNull("tensorboardDockerImage should be null!",
|
||||
jobRunParameters.getTensorboardDockerImage());
|
||||
tensorFlowParams.getTensorboardDockerImage());
|
||||
assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
|
||||
ImmutableMap.<String, String> builder()
|
||||
.put(ResourceInformation.GPU_URI, "3").build()),
|
||||
jobRunParameters.getTensorboardResource());
|
||||
tensorFlowParams.getTensorboardResource());
|
||||
}
|
||||
|
||||
@Test
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.client.cli;
|
||||
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
|
||||
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
|
||||
|
@ -42,14 +42,9 @@ import static org.junit.Assert.assertTrue;
|
|||
* Please note that this class just tests YAML parsing,
|
||||
* but only in an isolated fashion.
|
||||
*/
|
||||
public class TestRunJobCliParsingYamlStandalone {
|
||||
public class TestRunJobCliParsingTensorFlowYamlStandalone {
|
||||
private static final String OVERRIDDEN_PREFIX = "overridden_";
|
||||
private static final String DIR_NAME = "runjobcliparsing";
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
private static final String DIR_NAME = "runjob-tensorflow-yaml";
|
||||
|
||||
private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) {
|
||||
assertNotNull("Spec file should not be null!", yamlConfigFile);
|
||||
|
@ -169,6 +164,11 @@ public class TestRunJobCliParsingYamlStandalone {
|
|||
assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources());
|
||||
}
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
SubmarineLogs.verboseOff();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLaunchCommandYaml() {
|
||||
YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME +
|
||||
|
@ -201,5 +201,4 @@ public class TestRunJobCliParsingYamlStandalone {
|
|||
assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker");
|
||||
assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps");
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework:
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
ps:
|
||||
resources: memory=20500M,vcores=34,gpu=4
|
||||
replicas: 4
|
||||
launch_cmd: testLaunchCmdPs
|
||||
docker_image: testDockerImagePs
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
||||
|
||||
tensorBoard:
|
||||
resources: memory=21000M,vcores=37,gpu=3
|
||||
docker_image: tensorboardDockerImage
|
|
@ -0,0 +1,63 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: bla
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
ps:
|
||||
resources: memory=20500M,vcores=34,gpu=4
|
||||
replicas: 4
|
||||
launch_cmd: testLaunchCmdPs
|
||||
docker_image: testDockerImagePs
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
||||
|
||||
tensorBoard:
|
||||
resources: memory=21000M,vcores=37,gpu=3
|
||||
docker_image: tensorboardDockerImage
|
|
@ -17,6 +17,7 @@
|
|||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
|
@ -17,6 +17,7 @@
|
|||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
|
@ -0,0 +1,51 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: pytorch
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
|
@ -0,0 +1,56 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: pytorch
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
ps:
|
||||
docker_image: testPSDockerImage
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
|
@ -0,0 +1,57 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: pytorch
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
||||
|
||||
tensorBoard:
|
||||
docker_image: tensorboardDockerImage
|
|
@ -0,0 +1,53 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: pytorch
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
distribute_keytab: true
|
|
@ -0,0 +1,63 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: pytorch
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: overridden_testLaunchCmdWorker
|
||||
docker_image: overridden_testDockerImageWorker
|
||||
envs:
|
||||
env1: 'overridden_env1Worker'
|
||||
env2: 'overridden_env2Worker'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/overridden_local-filename1Worker:rw
|
||||
- nfs://remote-file2:/overridden_local-filename2Worker:rw
|
||||
mounts:
|
||||
- /etc/passwd:/overridden_Worker
|
||||
- /etc/hosts:/overridden_Worker
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
|
@ -0,0 +1,54 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: pytorch
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
|
@ -17,6 +17,7 @@
|
|||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
|
@ -17,6 +17,7 @@
|
|||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
|
@ -17,6 +17,7 @@
|
|||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
|
@ -17,6 +17,7 @@
|
|||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
|
@ -0,0 +1,63 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
spec:
|
||||
name: testJobName
|
||||
job_type: testJobType
|
||||
framework: tensorflow
|
||||
|
||||
configs:
|
||||
input_path: testInputPath
|
||||
checkpoint_path: testCheckpointPath
|
||||
saved_model_path: testSavedModelPath
|
||||
docker_image: testDockerImage
|
||||
wait_job_finish: true
|
||||
envs:
|
||||
env1: 'env1Value'
|
||||
env2: 'env2Value'
|
||||
localizations:
|
||||
- hdfs://remote-file1:/local-filename1:rw
|
||||
- nfs://remote-file2:/local-filename2:rw
|
||||
mounts:
|
||||
- /etc/passwd:/etc/passwd:rw
|
||||
- /etc/hosts:/etc/hosts:rw
|
||||
quicklinks:
|
||||
- Notebook_UI=https://master-0:7070
|
||||
- Notebook_UI2=https://master-0:7071
|
||||
|
||||
scheduling:
|
||||
queue: queue1
|
||||
|
||||
roles:
|
||||
worker:
|
||||
resources: memory=20480M,vcores=32,gpu=2
|
||||
replicas: 3
|
||||
launch_cmd: testLaunchCmdWorker
|
||||
docker_image: testDockerImageWorker
|
||||
ps:
|
||||
resources: memory=20500M,vcores=34,gpu=4
|
||||
replicas: 4
|
||||
launch_cmd: testLaunchCmdPs
|
||||
docker_image: testDockerImagePs
|
||||
|
||||
security:
|
||||
keytab: keytabPath
|
||||
principal: testPrincipal
|
||||
distribute_keytab: true
|
||||
|
||||
tensorBoard:
|
||||
resources: memory=21000M,vcores=37,gpu=3
|
||||
docker_image: tensorboardDockerImage
|
|
@ -22,7 +22,9 @@ import org.apache.commons.logging.LogFactory;
|
|||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -45,14 +47,24 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
|
|||
}
|
||||
|
||||
@Override
|
||||
public ApplicationId submitJob(RunJobParameters parameters)
|
||||
throws IOException, YarnException {
|
||||
public ApplicationId submitJob(ParametersHolder parameters)
|
||||
throws IOException {
|
||||
if (parameters.getFramework() == Framework.PYTORCH) {
|
||||
// We need to throw an exception, as ParametersHolder's parameters field
|
||||
// cannot be cast to TensorFlowRunJobParameters.
|
||||
throw new UnsupportedOperationException(
|
||||
"Support \"–-framework\" option for PyTorch in Tony is coming. " +
|
||||
"Please check the documentation about how to submit a " +
|
||||
"PyTorch job with TonY runtime.");
|
||||
}
|
||||
|
||||
LOG.info("Starting Tony runtime..");
|
||||
|
||||
File tonyFinalConfPath = File.createTempFile("temp",
|
||||
Constants.TONY_FINAL_XML);
|
||||
// Write user's overridden conf to an xml to be localized.
|
||||
Configuration tonyConf = TonyUtils.tonyConfFromClientContext(parameters);
|
||||
Configuration tonyConf = TonyUtils.tonyConfFromClientContext(
|
||||
(TensorFlowRunJobParameters) parameters.getParameters());
|
||||
try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
|
||||
tonyConf.writeXml(os);
|
||||
} catch (IOException e) {
|
||||
|
@ -68,7 +80,7 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
|
|||
LOG.error("Failed to init TonyClient: ", e);
|
||||
}
|
||||
Thread clientThread = new Thread(tonyClient::start);
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
java.lang.Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
tonyClient.forceKillApplication();
|
||||
} catch (YarnException | IOException e) {
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
|
|||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.api.records.ResourceInformation;
|
||||
import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -35,7 +35,7 @@ public final class TonyUtils {
|
|||
private static final Log LOG = LogFactory.getLog(TonyUtils.class);
|
||||
|
||||
public static Configuration tonyConfFromClientContext(
|
||||
RunJobParameters parameters) {
|
||||
TensorFlowRunJobParameters parameters) {
|
||||
Configuration tonyConf = new Configuration();
|
||||
tonyConf.setInt(
|
||||
TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME),
|
||||
|
|
|
@ -147,6 +147,7 @@ CLASSPATH=$(hadoop classpath --glob): \
|
|||
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
|
||||
|
||||
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
|
||||
--framework tensorflow \
|
||||
--num_workers 2 \
|
||||
--worker_resources memory=3G,vcores=2 \
|
||||
--num_ps 2 \
|
||||
|
@ -183,6 +184,7 @@ CLASSPATH=$(hadoop classpath --glob): \
|
|||
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
|
||||
|
||||
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
|
||||
--framework tensorflow \
|
||||
--docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
|
||||
--input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
|
||||
--worker_resources memory=3G,vcores=2 \
|
||||
|
@ -245,6 +247,7 @@ CLASSPATH=$(hadoop classpath --glob): \
|
|||
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
|
||||
|
||||
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
|
||||
--framework tensorflow \
|
||||
--num_workers 2 \
|
||||
--worker_resources memory=3G,vcores=2 \
|
||||
--num_ps 2 \
|
||||
|
@ -281,6 +284,7 @@ CLASSPATH=$(hadoop classpath --glob): \
|
|||
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
|
||||
|
||||
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
|
||||
--framework tensorflow \
|
||||
--docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
|
||||
--input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
|
||||
--worker_resources memory=3G,vcores=2 \
|
||||
|
|
|
@ -16,8 +16,10 @@ import com.linkedin.tony.TonyConfigurationKeys;
|
|||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
|
||||
|
@ -31,6 +33,7 @@ import org.junit.Test;
|
|||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
@ -59,7 +62,8 @@ public class TestTonyUtils {
|
|||
throws IOException, YarnException {
|
||||
MockClientContext mockClientContext = new MockClientContext();
|
||||
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
|
||||
when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
|
||||
when(mockJobSubmitter.submitJob(
|
||||
any(ParametersHolder.class))).thenReturn(
|
||||
ApplicationId.newInstance(1234L, 1));
|
||||
JobMonitor mockJobMonitor = mock(JobMonitor.class);
|
||||
SubmarineStorage storage = mock(SubmarineStorage.class);
|
||||
|
@ -82,20 +86,28 @@ public class TestTonyUtils {
|
|||
public void testTonyConfFromClientContext() throws Exception {
|
||||
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
|
||||
runJobCli.run(
|
||||
new String[] {"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
|
||||
new String[] {"--framework", "tensorflow", "--name", "my-job",
|
||||
"--docker_image", "tf-docker:1.1.0",
|
||||
"--input_path", "hdfs://input",
|
||||
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
|
||||
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
|
||||
"--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd",
|
||||
"python run-ps.py"});
|
||||
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
|
||||
|
||||
assertTrue(RunJobParameters.class + " must be an instance of " +
|
||||
TensorFlowRunJobParameters.class,
|
||||
jobRunParameters instanceof TensorFlowRunJobParameters);
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) jobRunParameters;
|
||||
|
||||
Configuration tonyConf = TonyUtils
|
||||
.tonyConfFromClientContext(jobRunParameters);
|
||||
.tonyConfFromClientContext(tensorFlowParams);
|
||||
Assert.assertEquals(jobRunParameters.getDockerImageName(),
|
||||
tonyConf.get(TonyConfigurationKeys.getContainerDockerKey()));
|
||||
Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys
|
||||
.getInstancesKey("worker")));
|
||||
Assert.assertEquals(jobRunParameters.getWorkerLaunchCmd(),
|
||||
Assert.assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
|
||||
tonyConf.get(TonyConfigurationKeys
|
||||
.getExecuteCommandKey("worker")));
|
||||
Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys
|
||||
|
@ -107,7 +119,7 @@ public class TestTonyUtils {
|
|||
Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
|
||||
.getResourceKey(Constants.PS_JOB_NAME,
|
||||
Constants.VCORES)));
|
||||
Assert.assertEquals(jobRunParameters.getPSLaunchCmd(),
|
||||
Assert.assertEquals(tensorFlowParams.getPSLaunchCmd(),
|
||||
tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,8 +19,10 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
|
|||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
|
||||
|
@ -28,7 +30,11 @@ import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchComma
|
|||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
|
||||
|
||||
/**
|
||||
* Abstract base class for Component classes.
|
||||
|
@ -40,7 +46,7 @@ import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.T
|
|||
public abstract class AbstractComponent {
|
||||
private final FileSystemOperations fsOperations;
|
||||
protected final RunJobParameters parameters;
|
||||
protected final TaskType taskType;
|
||||
protected final Role role;
|
||||
private final RemoteDirectoryManager remoteDirectoryManager;
|
||||
protected final Configuration yarnConfig;
|
||||
private final LaunchCommandFactory launchCommandFactory;
|
||||
|
@ -52,19 +58,55 @@ public abstract class AbstractComponent {
|
|||
|
||||
/**
 * Stores the collaborators a component needs to describe itself and to
 * generate and upload its launch script.
 *
 * @param fsOperations used to upload the launch script
 * @param remoteDirectoryManager used to resolve the job staging area
 * @param parameters run-job parameters of the submitted job
 * @param role role this component is created for
 * @param yarnConfig YARN configuration
 * @param launchCommandFactory factory producing the launch command
 */
public AbstractComponent(FileSystemOperations fsOperations,
    RemoteDirectoryManager remoteDirectoryManager,
    RunJobParameters parameters, Role role,
    Configuration yarnConfig,
    LaunchCommandFactory launchCommandFactory) {
  // Job description.
  this.parameters = parameters;
  this.role = role;
  this.yarnConfig = yarnConfig;

  // Helpers used while generating and uploading the launch script.
  this.fsOperations = fsOperations;
  this.remoteDirectoryManager = remoteDirectoryManager;
  this.launchCommandFactory = launchCommandFactory;
}
|
||||
|
||||
/**
 * Creates the service {@link Component} for this object; implemented by
 * concrete subclasses.
 *
 * @return the created component
 * @throws IOException if the implementation fails while creating it
 */
protected abstract Component createComponent() throws IOException;
|
||||
|
||||
protected Component createComponentInternal() throws IOException {
|
||||
Objects.requireNonNull(this.parameters.getWorkerResource(),
|
||||
"Worker resource must not be null!");
|
||||
if (parameters.getNumWorkers() < 1) {
|
||||
throw new IllegalArgumentException(
|
||||
"Number of workers should be at least 1!");
|
||||
}
|
||||
|
||||
Component component = new Component();
|
||||
component.setName(role.getComponentName());
|
||||
|
||||
if (role.equals(TensorFlowRole.PRIMARY_WORKER) ||
|
||||
role.equals(PyTorchRole.PRIMARY_WORKER)) {
|
||||
component.setNumberOfContainers(1L);
|
||||
component.getConfiguration().setProperty(
|
||||
CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
|
||||
} else {
|
||||
component.setNumberOfContainers(
|
||||
(long) parameters.getNumWorkers() - 1);
|
||||
}
|
||||
|
||||
if (parameters.getWorkerDockerImage() != null) {
|
||||
component.setArtifact(
|
||||
getDockerArtifact(parameters.getWorkerDockerImage()));
|
||||
}
|
||||
|
||||
component.setResource(convertYarnResourceToServiceResource(
|
||||
parameters.getWorkerResource()));
|
||||
component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
|
||||
|
||||
addCommonEnvironments(component, role);
|
||||
generateLaunchCommand(component);
|
||||
|
||||
return component;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a command launch script on local disk,
|
||||
* returns path to the script.
|
||||
|
@ -72,7 +114,7 @@ public abstract class AbstractComponent {
|
|||
protected void generateLaunchCommand(Component component)
|
||||
throws IOException {
|
||||
AbstractLaunchCommand launchCommand =
|
||||
launchCommandFactory.createLaunchCommand(taskType, component);
|
||||
launchCommandFactory.createLaunchCommand(role, component);
|
||||
this.localScriptFile = launchCommand.generateLaunchScript();
|
||||
|
||||
String remoteLaunchCommand = uploadLaunchCommand(component);
|
||||
|
@ -86,7 +128,7 @@ public abstract class AbstractComponent {
|
|||
Path stagingDir =
|
||||
remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
|
||||
|
||||
String destScriptFileName = getScriptFileName(taskType);
|
||||
String destScriptFileName = getScriptFileName(role);
|
||||
fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
|
||||
localScriptFile, destScriptFileName, component);
|
||||
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
|
||||
import org.apache.hadoop.yarn.service.api.records.Service;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
|
||||
import org.apache.hadoop.yarn.submarine.utils.Localizer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
|
||||
|
||||
/**
|
||||
* Abstract base class that supports creating service specs for Native Service.
|
||||
*/
|
||||
public abstract class AbstractServiceSpec implements ServiceSpec {
|
||||
private static final Logger LOG =
|
||||
LoggerFactory.getLogger(AbstractServiceSpec.class);
|
||||
protected final RunJobParameters parameters;
|
||||
protected final FileSystemOperations fsOperations;
|
||||
private final Localizer localizer;
|
||||
protected final RemoteDirectoryManager remoteDirectoryManager;
|
||||
protected final Configuration yarnConfig;
|
||||
protected final LaunchCommandFactory launchCommandFactory;
|
||||
private final WorkerComponentFactory workerFactory;
|
||||
|
||||
public AbstractServiceSpec(RunJobParameters parameters,
|
||||
ClientContext clientContext, FileSystemOperations fsOperations,
|
||||
LaunchCommandFactory launchCommandFactory,
|
||||
Localizer localizer) {
|
||||
this.parameters = parameters;
|
||||
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
|
||||
this.yarnConfig = clientContext.getYarnConfig();
|
||||
this.fsOperations = fsOperations;
|
||||
this.localizer = localizer;
|
||||
this.launchCommandFactory = launchCommandFactory;
|
||||
this.workerFactory = new WorkerComponentFactory(fsOperations,
|
||||
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
|
||||
}
|
||||
|
||||
protected ServiceWrapper createServiceSpecWrapper() throws IOException {
|
||||
Service serviceSpec = new Service();
|
||||
serviceSpec.setName(parameters.getName());
|
||||
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
|
||||
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
|
||||
|
||||
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
|
||||
.create(fsOperations, remoteDirectoryManager, parameters);
|
||||
if (kerberosPrincipal != null) {
|
||||
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
|
||||
}
|
||||
|
||||
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
|
||||
localizer.handleLocalizations(serviceSpec);
|
||||
return new ServiceWrapper(serviceSpec);
|
||||
}
|
||||
|
||||
|
||||
// Handle worker and primary_worker.
|
||||
protected void addWorkerComponents(ServiceWrapper serviceWrapper,
|
||||
Framework framework)
|
||||
throws IOException {
|
||||
final Role primaryWorkerRole;
|
||||
final Role workerRole;
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
primaryWorkerRole = TensorFlowRole.PRIMARY_WORKER;
|
||||
workerRole = TensorFlowRole.WORKER;
|
||||
} else {
|
||||
primaryWorkerRole = PyTorchRole.PRIMARY_WORKER;
|
||||
workerRole = PyTorchRole.WORKER;
|
||||
}
|
||||
|
||||
addWorkerComponent(serviceWrapper, primaryWorkerRole, framework);
|
||||
|
||||
if (parameters.getNumWorkers() > 1) {
|
||||
addWorkerComponent(serviceWrapper, workerRole, framework);
|
||||
}
|
||||
}
|
||||
private void addWorkerComponent(ServiceWrapper serviceWrapper,
|
||||
Role role, Framework framework) throws IOException {
|
||||
AbstractComponent component = workerFactory.create(framework, role);
|
||||
serviceWrapper.addComponent(component);
|
||||
}
|
||||
|
||||
protected void handleQuicklinks(Service serviceSpec)
|
||||
throws IOException {
|
||||
List<Quicklink> quicklinks = parameters.getQuicklinks();
|
||||
if (quicklinks != null && !quicklinks.isEmpty()) {
|
||||
for (Quicklink ql : quicklinks) {
|
||||
// Make sure it is a valid instance name
|
||||
String instanceName = ql.getComponentInstanceName();
|
||||
boolean found = false;
|
||||
|
||||
for (Component comp : serviceSpec.getComponents()) {
|
||||
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
|
||||
String possibleInstanceName = comp.getName() + "-" + i;
|
||||
if (possibleInstanceName.equals(instanceName)) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
throw new IOException(
|
||||
"Couldn't find a component instance = " + instanceName
|
||||
+ " while adding quicklink");
|
||||
}
|
||||
|
||||
String link = ql.getProtocol()
|
||||
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
|
||||
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
|
||||
addQuicklink(serviceSpec, ql.getLabel(), link);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected static void addQuicklink(Service serviceSpec, String label,
|
||||
String link) {
|
||||
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
|
||||
if (quicklinks == null) {
|
||||
quicklinks = new HashMap<>();
|
||||
serviceSpec.setQuicklinks(quicklinks);
|
||||
}
|
||||
|
||||
if (SubmarineLogs.isVerbose()) {
|
||||
LOG.info("Added quicklink, " + label + "=" + link);
|
||||
}
|
||||
|
||||
quicklinks.put(label, link);
|
||||
}
|
||||
}
|
|
@ -37,6 +37,7 @@ import java.io.File;
|
|||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
|
@ -195,6 +196,15 @@ public class FileSystemOperations {
|
|||
fs.setPermission(destPath, new FsPermission(permission));
|
||||
}
|
||||
|
||||
public static boolean needHdfs(List<String> stringsToCheck) {
|
||||
for (String content : stringsToCheck) {
|
||||
if (content != null && content.contains("hdfs://")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean needHdfs(String content) {
|
||||
return content != null && content.contains("hdfs://");
|
||||
}
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
|
||||
|
||||
import org.apache.curator.shaded.com.google.common.collect.ImmutableList;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
|
@ -28,6 +29,8 @@ import org.slf4j.LoggerFactory;
|
|||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
|
||||
|
@ -128,10 +131,22 @@ public class HadoopEnvironmentSetup {
|
|||
}
|
||||
|
||||
private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
|
||||
return needHdfs(parameters.getInputPath()) ||
|
||||
needHdfs(parameters.getPSLaunchCmd()) ||
|
||||
needHdfs(parameters.getWorkerLaunchCmd()) ||
|
||||
hadoopEnv;
|
||||
List<String> launchCommands = parameters.getLaunchCommands();
|
||||
if (launchCommands != null) {
|
||||
launchCommands.removeIf(Objects::isNull);
|
||||
}
|
||||
|
||||
ImmutableList.Builder<String> listBuilder = ImmutableList.builder();
|
||||
|
||||
if (launchCommands != null && !launchCommands.isEmpty()) {
|
||||
listBuilder.addAll(launchCommands);
|
||||
}
|
||||
if (parameters.getInputPath() != null) {
|
||||
listBuilder.add(parameters.getInputPath());
|
||||
}
|
||||
List<String> stringsToCheck = listBuilder.build();
|
||||
|
||||
return needHdfs(stringsToCheck) || hadoopEnv;
|
||||
}
|
||||
|
||||
private void appendHdfsHome(PrintWriter fw, String hdfsHome) {
|
||||
|
|
|
@ -38,7 +38,7 @@ public final class ServiceSpecFileGenerator {
|
|||
"instantiated!");
|
||||
}
|
||||
|
||||
static String generateJson(Service service) throws IOException {
|
||||
public static String generateJson(Service service) throws IOException {
|
||||
File serviceSpecFile = File.createTempFile(service.getName(), ".json");
|
||||
String buffer = jsonSerDeser.toJson(service);
|
||||
Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component.PyTorchWorkerComponent;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
|
||||
|
||||
/**
|
||||
* Factory class that helps creating Native Service components.
|
||||
*/
|
||||
public class WorkerComponentFactory {
|
||||
private final FileSystemOperations fsOperations;
|
||||
private final RemoteDirectoryManager remoteDirectoryManager;
|
||||
private final RunJobParameters parameters;
|
||||
private final LaunchCommandFactory launchCommandFactory;
|
||||
private final Configuration yarnConfig;
|
||||
|
||||
WorkerComponentFactory(FileSystemOperations fsOperations,
|
||||
RemoteDirectoryManager remoteDirectoryManager,
|
||||
RunJobParameters parameters,
|
||||
LaunchCommandFactory launchCommandFactory,
|
||||
Configuration yarnConfig) {
|
||||
this.fsOperations = fsOperations;
|
||||
this.remoteDirectoryManager = remoteDirectoryManager;
|
||||
this.parameters = parameters;
|
||||
this.launchCommandFactory = launchCommandFactory;
|
||||
this.yarnConfig = yarnConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates either a TensorFlow or a PyTorch Native Service component.
|
||||
*/
|
||||
public AbstractComponent create(Framework framework, Role role) {
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
return new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
|
||||
(TensorFlowRunJobParameters) parameters, role,
|
||||
(TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
|
||||
} else if (framework == Framework.PYTORCH) {
|
||||
return new PyTorchWorkerComponent(fsOperations, remoteDirectoryManager,
|
||||
(PyTorchRunJobParameters) parameters, role,
|
||||
(PyTorchLaunchCommandFactory) launchCommandFactory, yarnConfig);
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Only supported frameworks are: "
|
||||
+ Framework.getValues());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,10 +20,16 @@ import org.apache.hadoop.yarn.client.api.AppAdminClient;
|
|||
import org.apache.hadoop.yarn.exceptions.YarnException;
|
||||
import org.apache.hadoop.yarn.service.api.records.Service;
|
||||
import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
|
||||
import org.apache.hadoop.yarn.submarine.utils.Localizer;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -32,6 +38,7 @@ import org.slf4j.LoggerFactory;
|
|||
import java.io.IOException;
|
||||
|
||||
import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
|
||||
import static org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder.SUPPORTED_FRAMEWORKS_MESSAGE;
|
||||
|
||||
/**
|
||||
* Submit a job to cluster.
|
||||
|
@ -51,14 +58,45 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
|
|||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public ApplicationId submitJob(RunJobParameters parameters)
|
||||
public ApplicationId submitJob(ParametersHolder paramsHolder)
|
||||
throws IOException, YarnException {
|
||||
Framework framework = paramsHolder.getFramework();
|
||||
RunJobParameters parameters =
|
||||
(RunJobParameters) paramsHolder.getParameters();
|
||||
|
||||
if (framework == Framework.TENSORFLOW) {
|
||||
return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
|
||||
} else if (framework == Framework.PYTORCH) {
|
||||
return submitPyTorchJob((PyTorchRunJobParameters) parameters);
|
||||
} else {
|
||||
throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
private ApplicationId submitTensorFlowJob(
|
||||
TensorFlowRunJobParameters parameters) throws IOException, YarnException {
|
||||
FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
|
||||
HadoopEnvironmentSetup hadoopEnvSetup =
|
||||
new HadoopEnvironmentSetup(clientContext, fsOperations);
|
||||
|
||||
Service serviceSpec = createTensorFlowServiceSpec(parameters,
|
||||
fsOperations, hadoopEnvSetup);
|
||||
return submitJobInternal(serviceSpec);
|
||||
}
|
||||
|
||||
private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters)
|
||||
throws IOException, YarnException {
|
||||
FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
|
||||
HadoopEnvironmentSetup hadoopEnvSetup =
|
||||
new HadoopEnvironmentSetup(clientContext, fsOperations);
|
||||
|
||||
Service serviceSpec = createPyTorchServiceSpec(parameters,
|
||||
fsOperations, hadoopEnvSetup);
|
||||
return submitJobInternal(serviceSpec);
|
||||
}
|
||||
|
||||
private ApplicationId submitJobInternal(Service serviceSpec)
|
||||
throws IOException, YarnException {
|
||||
String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
|
||||
|
||||
AppAdminClient appAdminClient =
|
||||
|
@ -70,7 +108,7 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
|
|||
"Fail to launch application with exit code:" + code);
|
||||
}
|
||||
|
||||
String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
|
||||
String appStatus = appAdminClient.getStatusString(serviceSpec.getName());
|
||||
Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
|
||||
|
||||
// Retry multiple times if applicationId is null
|
||||
|
@ -97,11 +135,12 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
|
|||
return appid;
|
||||
}
|
||||
|
||||
private Service createTensorFlowServiceSpec(RunJobParameters parameters,
|
||||
private Service createTensorFlowServiceSpec(
|
||||
TensorFlowRunJobParameters parameters,
|
||||
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
|
||||
throws IOException {
|
||||
LaunchCommandFactory launchCommandFactory =
|
||||
new LaunchCommandFactory(hadoopEnvSetup, parameters,
|
||||
TensorFlowLaunchCommandFactory launchCommandFactory =
|
||||
new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters,
|
||||
clientContext.getYarnConfig());
|
||||
Localizer localizer = new Localizer(fsOperations,
|
||||
clientContext.getRemoteDirectoryManager(), parameters);
|
||||
|
@ -113,6 +152,22 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
|
|||
return serviceWrapper.getService();
|
||||
}
|
||||
|
||||
private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
|
||||
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
|
||||
throws IOException {
|
||||
PyTorchLaunchCommandFactory launchCommandFactory =
|
||||
new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters,
|
||||
clientContext.getYarnConfig());
|
||||
Localizer localizer = new Localizer(fsOperations,
|
||||
clientContext.getRemoteDirectoryManager(), parameters);
|
||||
PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
|
||||
parameters, this.clientContext, fsOperations, launchCommandFactory,
|
||||
localizer);
|
||||
|
||||
serviceWrapper = pyTorchServiceSpec.create();
|
||||
return serviceWrapper.getService();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public ServiceWrapper getServiceWrapper() {
|
||||
return serviceWrapper;
|
||||
|
|
|
@ -17,11 +17,9 @@
|
|||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
|
||||
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Abstract base class for Launch command implementations for Services.
|
||||
|
@ -32,10 +30,9 @@ public abstract class AbstractLaunchCommand {
|
|||
private final LaunchScriptBuilder builder;
|
||||
|
||||
public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
TaskType taskType, Component component, RunJobParameters parameters)
|
||||
throws IOException {
|
||||
Objects.requireNonNull(taskType, "TaskType must not be null!");
|
||||
this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
|
||||
Component component, RunJobParameters parameters,
|
||||
String launchCommandPrefix) throws IOException {
|
||||
this.builder = new LaunchScriptBuilder(launchCommandPrefix, hadoopEnvSetup,
|
||||
parameters, component);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,52 +16,15 @@
|
|||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Simple factory to create instances of {@link AbstractLaunchCommand}
|
||||
* based on the {@link TaskType}.
|
||||
* All dependencies are passed to this factory that could be required
|
||||
* by any implementor of {@link AbstractLaunchCommand}.
|
||||
* Interface for creating launch commands.
|
||||
*/
|
||||
public class LaunchCommandFactory {
|
||||
private final HadoopEnvironmentSetup hadoopEnvSetup;
|
||||
private final RunJobParameters parameters;
|
||||
private final Configuration yarnConfig;
|
||||
|
||||
public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
RunJobParameters parameters, Configuration yarnConfig) {
|
||||
this.hadoopEnvSetup = hadoopEnvSetup;
|
||||
this.parameters = parameters;
|
||||
this.yarnConfig = yarnConfig;
|
||||
}
|
||||
|
||||
public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
|
||||
Component component) throws IOException {
|
||||
Objects.requireNonNull(taskType, "TaskType must not be null!");
|
||||
|
||||
if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
|
||||
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
|
||||
component, parameters, yarnConfig);
|
||||
|
||||
} else if (taskType == TaskType.PS) {
|
||||
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
|
||||
parameters, yarnConfig);
|
||||
|
||||
} else if (taskType == TaskType.TENSORBOARD) {
|
||||
return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
|
||||
parameters);
|
||||
}
|
||||
throw new IllegalStateException("Unknown task type: " + taskType);
|
||||
}
|
||||
public interface LaunchCommandFactory {
|
||||
AbstractLaunchCommand createLaunchCommand(Role role, Component component)
|
||||
throws IOException;
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
|
||||
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -47,10 +47,11 @@ public class LaunchScriptBuilder {
|
|||
private final StringBuilder scriptBuffer;
|
||||
private String launchCommand;
|
||||
|
||||
LaunchScriptBuilder(String namePrefix,
|
||||
LaunchScriptBuilder(String launchScriptPrefix,
|
||||
HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
|
||||
Component component) throws IOException {
|
||||
this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
|
||||
this.file = File.createTempFile(launchScriptPrefix +
|
||||
"-launch-script", ".sh");
|
||||
this.hadoopEnvSetup = hadoopEnvSetup;
|
||||
this.parameters = parameters;
|
||||
this.component = component;
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command.PyTorchWorkerLaunchCommand;
|
||||
|
||||
/**
|
||||
* Simple factory to create instances of {@link AbstractLaunchCommand}
|
||||
* based on the {@link Role}.
|
||||
* All dependencies are passed to this factory that could be required
|
||||
* by any implementor of {@link AbstractLaunchCommand}.
|
||||
*/
|
||||
public class PyTorchLaunchCommandFactory implements LaunchCommandFactory {
|
||||
private final HadoopEnvironmentSetup hadoopEnvSetup;
|
||||
private final PyTorchRunJobParameters parameters;
|
||||
private final Configuration yarnConfig;
|
||||
|
||||
public PyTorchLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
PyTorchRunJobParameters parameters, Configuration yarnConfig) {
|
||||
this.hadoopEnvSetup = hadoopEnvSetup;
|
||||
this.parameters = parameters;
|
||||
this.yarnConfig = yarnConfig;
|
||||
}
|
||||
|
||||
public AbstractLaunchCommand createLaunchCommand(Role role,
|
||||
Component component) throws IOException {
|
||||
Objects.requireNonNull(role, "Role must not be null!");
|
||||
|
||||
if (role == PyTorchRole.WORKER ||
|
||||
role == PyTorchRole.PRIMARY_WORKER) {
|
||||
return new PyTorchWorkerLaunchCommand(hadoopEnvSetup, role,
|
||||
component, parameters, yarnConfig);
|
||||
|
||||
} else {
|
||||
throw new IllegalStateException("Unknown task type: " + role);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Simple factory to create instances of {@link AbstractLaunchCommand}
|
||||
* based on the {@link Role}.
|
||||
* All dependencies are passed to this factory that could be required
|
||||
* by any implementor of {@link AbstractLaunchCommand}.
|
||||
*/
|
||||
public class TensorFlowLaunchCommandFactory implements LaunchCommandFactory {
|
||||
private final HadoopEnvironmentSetup hadoopEnvSetup;
|
||||
private final TensorFlowRunJobParameters parameters;
|
||||
private final Configuration yarnConfig;
|
||||
|
||||
public TensorFlowLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
TensorFlowRunJobParameters parameters, Configuration yarnConfig) {
|
||||
this.hadoopEnvSetup = hadoopEnvSetup;
|
||||
this.parameters = parameters;
|
||||
this.yarnConfig = yarnConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AbstractLaunchCommand createLaunchCommand(Role role,
|
||||
Component component) throws IOException {
|
||||
Objects.requireNonNull(role, "Role must not be null!");
|
||||
|
||||
if (role == TensorFlowRole.WORKER ||
|
||||
role == TensorFlowRole.PRIMARY_WORKER) {
|
||||
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, role,
|
||||
component, parameters, yarnConfig);
|
||||
|
||||
} else if (role == TensorFlowRole.PS) {
|
||||
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, role, component,
|
||||
parameters, yarnConfig);
|
||||
|
||||
} else if (role == TensorFlowRole.TENSORBOARD) {
|
||||
return new TensorBoardLaunchCommand(hadoopEnvSetup, role, component,
|
||||
parameters);
|
||||
}
|
||||
throw new IllegalStateException("Unknown task type: " + role);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import org.apache.hadoop.yarn.service.api.records.Service;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.utils.Localizer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This class contains all the logic to create an instance
|
||||
* of a {@link Service} object for PyTorch.
|
||||
* Please note that currently, only single-node (non-distributed)
|
||||
* support is implemented for PyTorch.
|
||||
*/
|
||||
public class PyTorchServiceSpec extends AbstractServiceSpec {
|
||||
private static final Logger LOG =
|
||||
LoggerFactory.getLogger(PyTorchServiceSpec.class);
|
||||
//this field is needed in the future!
|
||||
private final PyTorchRunJobParameters pyTorchParameters;
|
||||
|
||||
public PyTorchServiceSpec(PyTorchRunJobParameters parameters,
|
||||
ClientContext clientContext, FileSystemOperations fsOperations,
|
||||
PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer) {
|
||||
super(parameters, clientContext, fsOperations, launchCommandFactory,
|
||||
localizer);
|
||||
this.pyTorchParameters = parameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServiceWrapper create() throws IOException {
|
||||
LOG.info("Creating PyTorch service spec");
|
||||
ServiceWrapper serviceWrapper = createServiceSpecWrapper();
|
||||
|
||||
if (parameters.getNumWorkers() > 0) {
|
||||
addWorkerComponents(serviceWrapper, Framework.PYTORCH);
|
||||
}
|
||||
|
||||
// After all components added, handle quicklinks
|
||||
handleQuicklinks(serviceWrapper.getService());
|
||||
|
||||
return serviceWrapper;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Launch command implementation for PyTorch components.
|
||||
*/
|
||||
public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
|
||||
private static final Logger LOG =
|
||||
LoggerFactory.getLogger(PyTorchWorkerLaunchCommand.class);
|
||||
private final Configuration yarnConfig;
|
||||
private final boolean distributed;
|
||||
private final int numberOfWorkers;
|
||||
private final String name;
|
||||
private final Role role;
|
||||
private final String launchCommand;
|
||||
|
||||
public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
Role role, Component component,
|
||||
PyTorchRunJobParameters parameters,
|
||||
Configuration yarnConfig) throws IOException {
|
||||
super(hadoopEnvSetup, component, parameters, role.getName());
|
||||
this.role = role;
|
||||
this.name = parameters.getName();
|
||||
this.distributed = parameters.isDistributed();
|
||||
this.numberOfWorkers = parameters.getNumWorkers();
|
||||
this.yarnConfig = yarnConfig;
|
||||
logReceivedParameters();
|
||||
|
||||
this.launchCommand = parameters.getWorkerLaunchCmd();
|
||||
|
||||
if (StringUtils.isEmpty(this.launchCommand)) {
|
||||
throw new IllegalArgumentException("LaunchCommand must not be null " +
|
||||
"or empty!");
|
||||
}
|
||||
}
|
||||
|
||||
private void logReceivedParameters() {
|
||||
if (this.numberOfWorkers <= 0) {
|
||||
LOG.warn("Received number of workers: {}", this.numberOfWorkers);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String generateLaunchScript() throws IOException {
|
||||
LaunchScriptBuilder builder = getBuilder();
|
||||
return builder
|
||||
.withLaunchCommand(createLaunchCommand())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String createLaunchCommand() {
|
||||
if (SubmarineLogs.isVerbose()) {
|
||||
LOG.info("PyTorch Worker command =[" + launchCommand + "]");
|
||||
}
|
||||
return launchCommand + '\n';
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* This package contains classes to generate PyTorch launch commands.
|
||||
*/
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Component implementation for Worker process of PyTorch.
|
||||
*/
|
||||
public class PyTorchWorkerComponent extends AbstractComponent {
|
||||
public PyTorchWorkerComponent(FileSystemOperations fsOperations,
|
||||
RemoteDirectoryManager remoteDirectoryManager,
|
||||
PyTorchRunJobParameters parameters, Role role,
|
||||
PyTorchLaunchCommandFactory launchCommandFactory,
|
||||
Configuration yarnConfig) {
|
||||
super(fsOperations, remoteDirectoryManager, parameters, role,
|
||||
yarnConfig, launchCommandFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Component createComponent() throws IOException {
|
||||
return createComponentInternal();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* This package contains classes to generate
|
||||
* PyTorch Native Service components.
|
||||
*/
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* This package contains classes to generate
|
||||
* PyTorch-related Native Service runtime artifacts.
|
||||
*/
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
|
|
@ -20,7 +20,7 @@ import org.apache.hadoop.conf.Configuration;
|
|||
import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.common.Envs;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
|
||||
|
||||
import java.util.Map;
|
||||
|
@ -35,10 +35,10 @@ public final class TensorFlowCommons {
|
|||
}
|
||||
|
||||
public static void addCommonEnvironments(Component component,
|
||||
TaskType taskType) {
|
||||
Role role) {
|
||||
Map<String, String> envs = component.getConfiguration().getEnv();
|
||||
envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
|
||||
envs.put(Envs.TASK_TYPE_ENV, taskType.name());
|
||||
envs.put(Envs.TASK_TYPE_ENV, role.getName());
|
||||
}
|
||||
|
||||
public static String getUserName() {
|
||||
|
@ -49,8 +49,8 @@ public final class TensorFlowCommons {
|
|||
return yarnConfig.get("hadoop.registry.dns.domain-name");
|
||||
}
|
||||
|
||||
public static String getScriptFileName(TaskType taskType) {
|
||||
return "run-" + taskType.name() + ".sh";
|
||||
public static String getScriptFileName(Role role) {
|
||||
return "run-" + role.getName() + ".sh";
|
||||
}
|
||||
|
||||
public static String getTFConfigEnv(String componentName, int nWorkers,
|
||||
|
|
|
@ -16,39 +16,24 @@
|
|||
|
||||
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
|
||||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
|
||||
import org.apache.hadoop.yarn.service.api.records.Service;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
|
||||
import org.apache.hadoop.yarn.submarine.common.ClientContext;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
|
||||
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
|
||||
import org.apache.hadoop.yarn.submarine.utils.Localizer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
|
||||
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
|
||||
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
|
||||
|
||||
/**
|
||||
* This class contains all the logic to create an instance
|
||||
|
@ -56,42 +41,34 @@ import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handle
|
|||
* Worker,PS and Tensorboard components are added to the Service
|
||||
* based on the value of the received {@link RunJobParameters}.
|
||||
*/
|
||||
public class TensorFlowServiceSpec implements ServiceSpec {
|
||||
public class TensorFlowServiceSpec extends AbstractServiceSpec {
|
||||
private static final Logger LOG =
|
||||
LoggerFactory.getLogger(TensorFlowServiceSpec.class);
|
||||
private final TensorFlowRunJobParameters tensorFlowParameters;
|
||||
|
||||
private final RemoteDirectoryManager remoteDirectoryManager;
|
||||
|
||||
private final RunJobParameters parameters;
|
||||
private final Configuration yarnConfig;
|
||||
private final FileSystemOperations fsOperations;
|
||||
private final LaunchCommandFactory launchCommandFactory;
|
||||
private final Localizer localizer;
|
||||
|
||||
public TensorFlowServiceSpec(RunJobParameters parameters,
|
||||
public TensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
|
||||
ClientContext clientContext, FileSystemOperations fsOperations,
|
||||
LaunchCommandFactory launchCommandFactory, Localizer localizer) {
|
||||
this.parameters = parameters;
|
||||
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
|
||||
this.yarnConfig = clientContext.getYarnConfig();
|
||||
this.fsOperations = fsOperations;
|
||||
this.launchCommandFactory = launchCommandFactory;
|
||||
this.localizer = localizer;
|
||||
TensorFlowLaunchCommandFactory launchCommandFactory,
|
||||
Localizer localizer) {
|
||||
super(parameters, clientContext, fsOperations, launchCommandFactory,
|
||||
localizer);
|
||||
this.tensorFlowParameters = parameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServiceWrapper create() throws IOException {
|
||||
LOG.info("Creating TensorFlow service spec");
|
||||
ServiceWrapper serviceWrapper = createServiceSpecWrapper();
|
||||
|
||||
if (parameters.getNumWorkers() > 0) {
|
||||
addWorkerComponents(serviceWrapper);
|
||||
if (tensorFlowParameters.getNumWorkers() > 0) {
|
||||
addWorkerComponents(serviceWrapper, Framework.TENSORFLOW);
|
||||
}
|
||||
|
||||
if (parameters.getNumPS() > 0) {
|
||||
if (tensorFlowParameters.getNumPS() > 0) {
|
||||
addPsComponent(serviceWrapper);
|
||||
}
|
||||
|
||||
if (parameters.isTensorboardEnabled()) {
|
||||
if (tensorFlowParameters.isTensorboardEnabled()) {
|
||||
createTensorBoardComponent(serviceWrapper);
|
||||
}
|
||||
|
||||
|
@ -101,103 +78,23 @@ public class TensorFlowServiceSpec implements ServiceSpec {
|
|||
return serviceWrapper;
|
||||
}
|
||||
|
||||
private ServiceWrapper createServiceSpecWrapper() throws IOException {
|
||||
Service serviceSpec = new Service();
|
||||
serviceSpec.setName(parameters.getName());
|
||||
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
|
||||
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
|
||||
|
||||
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
|
||||
.create(fsOperations, remoteDirectoryManager, parameters);
|
||||
if (kerberosPrincipal != null) {
|
||||
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
|
||||
}
|
||||
|
||||
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
|
||||
localizer.handleLocalizations(serviceSpec);
|
||||
return new ServiceWrapper(serviceSpec);
|
||||
}
|
||||
|
||||
private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
|
||||
throws IOException {
|
||||
TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
|
||||
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
|
||||
remoteDirectoryManager, parameters,
|
||||
(TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
|
||||
serviceWrapper.addComponent(tbComponent);
|
||||
|
||||
addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
|
||||
tbComponent.getTensorboardLink());
|
||||
}
|
||||
|
||||
private static void addQuicklink(Service serviceSpec, String label,
|
||||
String link) {
|
||||
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
|
||||
if (quicklinks == null) {
|
||||
quicklinks = new HashMap<>();
|
||||
serviceSpec.setQuicklinks(quicklinks);
|
||||
}
|
||||
|
||||
if (SubmarineLogs.isVerbose()) {
|
||||
LOG.info("Added quicklink, " + label + "=" + link);
|
||||
}
|
||||
|
||||
quicklinks.put(label, link);
|
||||
}
|
||||
|
||||
private void handleQuicklinks(Service serviceSpec)
|
||||
throws IOException {
|
||||
List<Quicklink> quicklinks = parameters.getQuicklinks();
|
||||
if (quicklinks != null && !quicklinks.isEmpty()) {
|
||||
for (Quicklink ql : quicklinks) {
|
||||
// Make sure it is a valid instance name
|
||||
String instanceName = ql.getComponentInstanceName();
|
||||
boolean found = false;
|
||||
|
||||
for (Component comp : serviceSpec.getComponents()) {
|
||||
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
|
||||
String possibleInstanceName = comp.getName() + "-" + i;
|
||||
if (possibleInstanceName.equals(instanceName)) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
throw new IOException(
|
||||
"Couldn't find a component instance = " + instanceName
|
||||
+ " while adding quicklink");
|
||||
}
|
||||
|
||||
String link = ql.getProtocol()
|
||||
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
|
||||
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
|
||||
addQuicklink(serviceSpec, ql.getLabel(), link);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle worker and primary_worker.
|
||||
|
||||
private void addWorkerComponents(ServiceWrapper serviceWrapper)
|
||||
throws IOException {
|
||||
addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
|
||||
|
||||
if (parameters.getNumWorkers() > 1) {
|
||||
addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
|
||||
}
|
||||
}
|
||||
private void addWorkerComponent(ServiceWrapper serviceWrapper,
|
||||
RunJobParameters parameters, TaskType taskType) throws IOException {
|
||||
serviceWrapper.addComponent(
|
||||
new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
|
||||
parameters, taskType, launchCommandFactory, yarnConfig));
|
||||
}
|
||||
|
||||
private void addPsComponent(ServiceWrapper serviceWrapper)
|
||||
throws IOException {
|
||||
serviceWrapper.addComponent(
|
||||
new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
|
||||
launchCommandFactory, parameters, yarnConfig));
|
||||
(TensorFlowLaunchCommandFactory) launchCommandFactory,
|
||||
parameters, yarnConfig));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
|
|||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -37,9 +37,9 @@ public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
|
|||
private final String checkpointPath;
|
||||
|
||||
public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
TaskType taskType, Component component, RunJobParameters parameters)
|
||||
Role role, Component component, RunJobParameters parameters)
|
||||
throws IOException {
|
||||
super(hadoopEnvSetup, taskType, component, parameters);
|
||||
super(hadoopEnvSetup, component, parameters, role.getName());
|
||||
Objects.requireNonNull(parameters.getCheckpointPath(),
|
||||
"CheckpointPath must not be null as it is part "
|
||||
+ "of the tensorboard command!");
|
||||
|
|
|
@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
|
|||
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
|
||||
|
@ -28,6 +28,7 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Launch command implementation for
|
||||
|
@ -41,13 +42,16 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
|
|||
private final int numberOfWorkers;
|
||||
private final int numberOfPS;
|
||||
private final String name;
|
||||
private final TaskType taskType;
|
||||
private final Role role;
|
||||
|
||||
TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
TaskType taskType, Component component, RunJobParameters parameters,
|
||||
Role role, Component component,
|
||||
TensorFlowRunJobParameters parameters,
|
||||
Configuration yarnConfig) throws IOException {
|
||||
super(hadoopEnvSetup, taskType, component, parameters);
|
||||
this.taskType = taskType;
|
||||
super(hadoopEnvSetup, component, parameters,
|
||||
role != null ? role.getName(): "");
|
||||
Objects.requireNonNull(role, "TensorFlowRole must not be null!");
|
||||
this.role = role;
|
||||
this.name = parameters.getName();
|
||||
this.distributed = parameters.isDistributed();
|
||||
this.numberOfWorkers = parameters.getNumWorkers();
|
||||
|
@ -72,7 +76,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
|
|||
// When distributed training is required
|
||||
if (distributed) {
|
||||
String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
|
||||
taskType.getComponentName(), numberOfWorkers,
|
||||
role.getComponentName(), numberOfWorkers,
|
||||
numberOfPS, name,
|
||||
TensorFlowCommons.getUserName(),
|
||||
TensorFlowCommons.getDNSDomain(yarnConfig));
|
||||
|
|
|
@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -37,9 +37,10 @@ public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
|
|||
private final String launchCommand;
|
||||
|
||||
public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
|
||||
TaskType taskType, Component component, RunJobParameters parameters,
|
||||
Role role, Component component,
|
||||
TensorFlowRunJobParameters parameters,
|
||||
Configuration yarnConfig) throws IOException {
|
||||
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
|
||||
super(hadoopEnvSetup, role, component, parameters, yarnConfig);
|
||||
this.launchCommand = parameters.getPSLaunchCmd();
|
||||
|
||||
if (StringUtils.isEmpty(this.launchCommand)) {
|
||||
|
|
|
@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.Role;
|
||||
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -37,10 +37,10 @@ public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
|
|||
private final String launchCommand;
|
||||
|
||||
public TensorFlowWorkerLaunchCommand(
|
||||
HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
|
||||
Component component, RunJobParameters parameters,
|
||||
HadoopEnvironmentSetup hadoopEnvSetup, Role role,
|
||||
Component component, TensorFlowRunJobParameters parameters,
|
||||
Configuration yarnConfig) throws IOException {
|
||||
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
|
||||
super(hadoopEnvSetup, role, component, parameters, yarnConfig);
|
||||
this.launchCommand = parameters.getWorkerLaunchCmd();
|
||||
|
||||
if (StringUtils.isEmpty(this.launchCommand)) {
|
||||
|
|
|
@ -19,13 +19,14 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.compone
|
|||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component;
|
||||
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
|
||||
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
|
||||
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
|
||||
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -54,35 +55,38 @@ public class TensorBoardComponent extends AbstractComponent {
|
|||
public TensorBoardComponent(FileSystemOperations fsOperations,
|
||||
RemoteDirectoryManager remoteDirectoryManager,
|
||||
RunJobParameters parameters,
|
||||
LaunchCommandFactory launchCommandFactory,
|
||||
TensorFlowLaunchCommandFactory launchCommandFactory,
|
||||
Configuration yarnConfig) {
|
||||
super(fsOperations, remoteDirectoryManager, parameters,
|
||||
TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
|
||||
TensorFlowRole.TENSORBOARD, yarnConfig, launchCommandFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Component createComponent() throws IOException {
|
||||
Objects.requireNonNull(parameters.getTensorboardResource(),
|
||||
TensorFlowRunJobParameters tensorFlowParams =
|
||||
(TensorFlowRunJobParameters) this.parameters;
|
||||
|
||||
Objects.requireNonNull(tensorFlowParams.getTensorboardResource(),
|
||||
"TensorBoard resource must not be null!");
|
||||
|
||||
Component component = new Component();
|
||||
component.setName(taskType.getComponentName());
|
||||
component.setName(role.getComponentName());
|
||||
component.setNumberOfContainers(1L);
|
||||
component.setRestartPolicy(RestartPolicyEnum.NEVER);
|
||||
component.setResource(convertYarnResourceToServiceResource(
|
||||
parameters.getTensorboardResource()));
|
||||
tensorFlowParams.getTensorboardResource()));
|
||||
|
||||
if (parameters.getTensorboardDockerImage() != null) {
|
||||
if (tensorFlowParams.getTensorboardDockerImage() != null) {
|
||||
component.setArtifact(
|
||||
getDockerArtifact(parameters.getTensorboardDockerImage()));
|
||||
getDockerArtifact(tensorFlowParams.getTensorboardDockerImage()));
|
||||
}
|
||||
|
||||
addCommonEnvironments(component, taskType);
|
||||
addCommonEnvironments(component, role);
|
||||
generateLaunchCommand(component);
|
||||
|
||||
tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
|
||||
parameters.getName(),
|
||||
taskType.getComponentName() + "-" + 0, getUserName(),
|
||||
role.getComponentName() + "-" + 0, getUserName(),
|
||||
getDNSDomain(yarnConfig), DEFAULT_PORT);
|
||||
LOG.info("Link to tensorboard:" + tensorboardLink);
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue