SUBMARINE-52. [SUBMARINE-14] Generate Service spec + launch script for single-node PyTorch learning job. Contributed by Szilard Nemeth.
Commit 36267b6f7c (parent 64c7f36ab1)

base/ubuntu-16.04/Dockerfile.gpu.pytorch_latest (new file):
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
ARG PYTHON_VERSION=3.6
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        curl \
        vim \
        ca-certificates \
        libjpeg-dev \
        libpng-dev \
        wget && \
    rm -rf /var/lib/apt/lists/*


RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
    /opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \
    /opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH
RUN pip install ninja
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/pytorch.git
WORKDIR pytorch
RUN git submodule update --init
RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
    pip install -v .

WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v .

WORKDIR /
# Install Hadoop
ENV HADOOP_VERSION="3.1.2"
RUN wget https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz
RUN tar zxf hadoop-${HADOOP_VERSION}.tar.gz
RUN ln -s hadoop-${HADOOP_VERSION} hadoop-current
RUN rm hadoop-${HADOOP_VERSION}.tar.gz

ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
RUN echo "$LOG_TAG Install java8" && \
    apt-get update && \
    apt-get install -y --no-install-recommends openjdk-8-jdk && \
    apt-get clean && rm -rf /var/lib/apt/lists/*

RUN echo "Install python related packages" && \
    pip --no-cache-dir install Pillow h5py ipykernel jupyter matplotlib numpy pandas scipy sklearn && \
    python -m ipykernel.kernelspec

# Set the locale to fix bash warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
RUN apt-get update && apt-get install -y --no-install-recommends locales && \
    apt-get clean && rm -rf /var/lib/apt/lists/*
RUN locale-gen en_US.UTF-8


WORKDIR /workspace
RUN chmod -R a+w /workspace
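A quick sanity check of the base image after building it (a hedged sketch, not part of the
commit; it assumes a local Docker daemon with the nvidia-docker 2 runtime and the image
tag that the build script below assigns):

    # Build the base image the same way the script below does, then verify
    # that the from-source PyTorch build can see CUDA inside the container.
    docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu-base:0.0.1
    docker run --rm --runtime=nvidia pytorch-latest-gpu-base:0.0.1 \
        python -c 'import torch; print(torch.__version__, torch.cuda.is_available())'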
Docker image build script (new file):
@@ -0,0 +1,30 @@
#!/usr/bin/env bash

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

echo "Building base images"

set -e

cd base/ubuntu-16.04

docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu-base:0.0.1

echo "Finished building base images"

cd ../../with-cifar10-models/ubuntu-16.04

docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu:0.0.1
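For YARN to launch containers from these tags, every NodeManager host must be able to pull
them. A minimal sketch of publishing the images (not part of the commit; the registry host
is a placeholder, and the trusted-registry setting is an assumption about the cluster's
container-executor.cfg):

    # Re-tag and push both images to a registry reachable from all cluster nodes.
    # That registry should also appear in docker.trusted.registries in
    # container-executor.cfg, or YARN will refuse to run the containers.
    REGISTRY=registry.example.com:5000
    for img in pytorch-latest-gpu-base:0.0.1 pytorch-latest-gpu:0.0.1; do
        docker tag "$img" "$REGISTRY/$img"
        docker push "$REGISTRY/$img"
    done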
cifar10_tutorial.py (new file):
@@ -0,0 +1,354 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# -*- coding: utf-8 -*-
"""
Training a Classifier
=====================

This is it. You have seen how to define neural networks, compute loss and make
updates to the weights of the network.

Now you might be thinking,

What about data?
----------------

Generally, when you have to deal with image, text, audio or video data,
you can use standard python packages that load data into a numpy array.
Then you can convert this array into a ``torch.*Tensor``.

- For images, packages such as Pillow, OpenCV are useful
- For audio, packages such as scipy and librosa
- For text, either raw Python or Cython based loading, or NLTK and
  SpaCy are useful

Specifically for vision, we have created a package called
``torchvision``, that has data loaders for common datasets such as
Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
``torchvision.datasets`` and ``torch.utils.data.DataLoader``.

This provides a huge convenience and avoids writing boilerplate code.

For this tutorial, we will use the CIFAR10 dataset.
It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of
size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

.. figure:: /_static/img/cifar10.png
   :alt: cifar10

   cifar10


Training an image classifier
----------------------------

We will do the following steps in order:

1. Load and normalizing the CIFAR10 training and test datasets using
   ``torchvision``
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data

1. Loading and normalizing CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using ``torchvision``, it’s extremely easy to load CIFAR10.
"""
import torch
import torchvision
import torchvision.transforms as transforms

########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

########################################################################
# Let us show some of the training images, for fun.

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

########################################################################
# 2. Define a Convolutional Neural Network
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Copy the neural network from the Neural Networks section before and modify it to
# take 3-channel images (instead of 1-channel images as it was defined).

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

########################################################################
# 3. Define a Loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's use a Classification Cross-Entropy loss and SGD with momentum.

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize.

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

########################################################################
# 5. Test the network on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We have trained the network for 2 passes over the training dataset.
# But we need to check if the network has learnt anything at all.
#
# We will check this by predicting the class label that the neural network
# outputs, and checking it against the ground-truth. If the prediction is
# correct, we add the sample to the list of correct predictions.
#
# Okay, first step. Let us display an image from the test set to get familiar.

dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

########################################################################
# Okay, now let us see what the neural network thinks these examples above are:

outputs = net(images)

########################################################################
# The outputs are energies for the 10 classes.
# The higher the energy for a class, the more the network
# thinks that the image is of the particular class.
# So, let's get the index of the highest energy:
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))

########################################################################
# The results seem pretty good.
#
# Let us look at how the network performs on the whole dataset.

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

########################################################################
# That looks waaay better than chance, which is 10% accuracy (randomly picking
# a class out of 10 classes).
# Seems like the network learnt something.
#
# Hmmm, what are the classes that performed well, and the classes that did
# not perform well:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

########################################################################
# Okay, so what next?
#
# How do we run these neural networks on the GPU?
#
# Training on GPU
# ----------------
# Just like how you transfer a Tensor onto the GPU, you transfer the neural
# net onto the GPU.
#
# Let's first define our device as the first visible cuda device if we have
# CUDA available:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

########################################################################
# The rest of this section assumes that ``device`` is a CUDA device.
#
# Then these methods will recursively go over all modules and convert their
# parameters and buffers to CUDA tensors:
#
# .. code:: python
#
#     net.to(device)
#
#
# Remember that you will have to send the inputs and targets at every step
# to the GPU too:
#
# .. code:: python
#
#         inputs, labels = inputs.to(device), labels.to(device)
#
# Why dont I notice MASSIVE speedup compared to CPU? Because your network
# is realllly small.
#
# **Exercise:** Try increasing the width of your network (argument 2 of
# the first ``nn.Conv2d``, and argument 1 of the second ``nn.Conv2d`` –
# they need to be the same number), see what kind of speedup you get.
#
# **Goals achieved**:
#
# - Understanding PyTorch's Tensor library and neural networks at a high level.
# - Train a small neural network to classify images
#
# Training on multiple GPUs
# -------------------------
# If you want to see even more MASSIVE speedup using all of your GPUs,
# please check out :doc:`data_parallel_tutorial`.
#
# Where do I go next?
# -------------------
#
# - :doc:`Train neural nets to play video games </intermediate/reinforcement_q_learning>`
# - `Train a state-of-the-art ResNet network on imagenet`_
# - `Train a face generator using Generative Adversarial Networks`_
# - `Train a word-level language model using Recurrent LSTM networks`_
# - `More examples`_
# - `More tutorials`_
# - `Discuss PyTorch on the Forums`_
# - `Chat with other users on Slack`_
#
# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
# .. _More examples: https://github.com/pytorch/examples
# .. _More tutorials: https://github.com/pytorch/tutorials
# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
# .. _Chat with other users on Slack: https://pytorch.slack.com/messages/beginner/

# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
del dataiter
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
with-cifar10-models/ubuntu-16.04/Dockerfile.gpu.pytorch_latest (new file):
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM pytorch-latest-gpu-base:0.0.1

RUN mkdir -p /test/data
RUN chmod -R 777 /test
ADD cifar10_tutorial.py /test/cifar10_tutorial.py
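Because the Dockerfile above bakes the tutorial into /test, the job image can be
smoke-tested outside YARN before a Submarine submission (a hedged sketch, not part of the
commit; it assumes the image was built locally and the host has the NVIDIA runtime):

    # CIFAR-10 is downloaded under /test/data inside the throwaway container;
    # the script trains for two epochs and prints per-class accuracy.
    docker run --rm --runtime=nvidia -w /test pytorch-latest-gpu:0.0.1 \
        python /test/cifar10_tutorial.py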
@@ -15,6 +15,7 @@
 package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
@@ -59,4 +59,6 @@ public class CliConstants {
   public static final String DISTRIBUTE_KEYTAB = "distribute_keytab";
   public static final String YAML_CONFIG = "f";
   public static final String INSECURE_CLUSTER = "insecure";
+
+  public static final String FRAMEWORK = "framework";
 }
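With the FRAMEWORK constant wired into the CLI, a single-node PyTorch submission would
look roughly like this (a hedged sketch, not from the commit; the jar name and resource
sizes are placeholders):

    # --framework selects between TensorFlow (the default) and PyTorch.
    yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
        --name pytorch-cifar10 \
        --framework pytorch \
        --docker_image pytorch-latest-gpu:0.0.1 \
        --num_workers 1 \
        --worker_resources memory=8G,vcores=4,gpu=1 \
        --worker_launch_cmd "python /test/cifar10_tutorial.py"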
@@ -16,7 +16,7 @@ package org.apache.hadoop.yarn.submarine.client.cli;
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.slf4j.Logger;
Command.java (new file):
@@ -0,0 +1,24 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli;

/**
 * Represents a Submarine command.
 */
public enum Command {
  RUN_JOB, SHOW_JOB
}
@@ -37,7 +37,7 @@ public class ShowJobCli extends AbstractCli {
   private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class);
 
   private Options options;
-  private ShowJobParameters parameters = new ShowJobParameters();
+  private ParametersHolder parametersHolder;
 
   public ShowJobCli(ClientContext cliContext) {
     super(cliContext);
@@ -62,9 +62,9 @@ public class ShowJobCli extends AbstractCli {
     CommandLine cli;
     try {
       cli = parser.parse(options, args);
-      ParametersHolder parametersHolder = ParametersHolder
-          .createWithCmdLine(cli);
-      parameters.updateParameters(parametersHolder, clientContext);
+      parametersHolder = ParametersHolder
+          .createWithCmdLine(cli, Command.SHOW_JOB);
+      parametersHolder.updateParameters(clientContext);
     } catch (ParseException e) {
       printUsages();
     }
@@ -97,7 +97,7 @@ public class ShowJobCli extends AbstractCli {
 
     Map<String, String> jobInfo = null;
     try {
-      jobInfo = storage.getJobInfoByName(parameters.getName());
+      jobInfo = storage.getJobInfoByName(getParameters().getName());
     } catch (IOException e) {
       LOG.error("Failed to retrieve job info", e);
       throw e;
@@ -108,7 +108,7 @@ public class ShowJobCli extends AbstractCli {
 
   @VisibleForTesting
   public ShowJobParameters getParameters() {
-    return parameters;
+    return (ShowJobParameters) parametersHolder.getParameters();
   }
 
   @Override
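After this change the show-job path builds its parameters through the same
ParametersHolder as run-job, keyed by Command.SHOW_JOB; the user-facing invocation is
unchanged (sketch; the jar name is a placeholder):

    # Prints the stored info of a previously submitted job.
    yarn jar hadoop-yarn-applications-submarine-<version>.jar job show \
        --name pytorch-cifar10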
ConfigType.java (new file):
@@ -0,0 +1,24 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.param;

/**
 * Represents the source of configuration.
 */
public enum ConfigType {
  YAML, CLI
}
@@ -20,8 +20,12 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.ParseException;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.Command;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles;
@@ -29,15 +33,22 @@ import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Scheduling;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
+import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli.YAML_PARSE_FAILED;
+
 /**
  * This class acts as a wrapper of {@code CommandLine} values along with
  * YAML configuration values.
@@ -52,17 +63,110 @@ public final class ParametersHolder {
   private static final Logger LOG =
       LoggerFactory.getLogger(ParametersHolder.class);
 
+  public static final String SUPPORTED_FRAMEWORKS_MESSAGE =
+      "TensorFlow and PyTorch are the only supported frameworks for now!";
+  public static final String SUPPORTED_COMMANDS_MESSAGE =
+      "'Show job' and 'run job' are the only supported commands for now!";
+
+
   private final CommandLine parsedCommandLine;
   private final Map<String, String> yamlStringConfigs;
   private final Map<String, List<String>> yamlListConfigs;
-  private final ImmutableSet onlyDefinedWithCliArgs = ImmutableSet.of(
+  private final ConfigType configType;
+  private Command command;
+  private final Set onlyDefinedWithCliArgs = ImmutableSet.of(
       CliConstants.VERBOSE);
+  private final Framework framework;
+  private final BaseParameters parameters;
 
   private ParametersHolder(CommandLine parsedCommandLine,
-      YamlConfigFile yamlConfig) {
+      YamlConfigFile yamlConfig, ConfigType configType, Command command)
+      throws ParseException, YarnException {
     this.parsedCommandLine = parsedCommandLine;
     this.yamlStringConfigs = initStringConfigValues(yamlConfig);
     this.yamlListConfigs = initListConfigValues(yamlConfig);
+    this.configType = configType;
+    this.command = command;
+    this.framework = determineFrameworkType();
+    this.ensureOnlyValidSectionsAreDefined(yamlConfig);
+    this.parameters = createParameters();
+  }
+
+  private BaseParameters createParameters() {
+    if (command == Command.RUN_JOB) {
+      if (framework == Framework.TENSORFLOW) {
+        return new TensorFlowRunJobParameters();
+      } else if (framework == Framework.PYTORCH) {
+        return new PyTorchRunJobParameters();
+      } else {
+        throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
+      }
+    } else if (command == Command.SHOW_JOB) {
+      return new ShowJobParameters();
+    } else {
+      throw new UnsupportedOperationException(SUPPORTED_COMMANDS_MESSAGE);
+    }
+  }
+
+  private void ensureOnlyValidSectionsAreDefined(YamlConfigFile yamlConfig) {
+    if (isCommandRunJob() && isFrameworkPyTorch() &&
+        isPsSectionDefined(yamlConfig)) {
+      throw new YamlParseException(
+          "PS section should not be defined when PyTorch " +
+              "is the selected framework!");
+    }
+
+    if (isCommandRunJob() && isFrameworkPyTorch() &&
+        isTensorboardSectionDefined(yamlConfig)) {
+      throw new YamlParseException(
+          "TensorBoard section should not be defined when PyTorch " +
+              "is the selected framework!");
+    }
+  }
+
+  private boolean isCommandRunJob() {
+    return command == Command.RUN_JOB;
+  }
+
+  private boolean isFrameworkPyTorch() {
+    return framework == Framework.PYTORCH;
+  }
+
+  private boolean isPsSectionDefined(YamlConfigFile yamlConfig) {
+    return yamlConfig != null &&
+        yamlConfig.getRoles() != null &&
+        yamlConfig.getRoles().getPs() != null;
+  }
+
+  private boolean isTensorboardSectionDefined(YamlConfigFile yamlConfig) {
+    return yamlConfig != null &&
+        yamlConfig.getTensorBoard() != null;
+  }
+
+  private Framework determineFrameworkType()
+      throws ParseException, YarnException {
+    if (!isCommandRunJob()) {
+      return null;
+    }
+    String frameworkStr = getOptionValue(CliConstants.FRAMEWORK);
+    if (frameworkStr == null) {
+      LOG.info("Framework is not defined in config, falling back to " +
+          "TensorFlow as a default.");
+      return Framework.TENSORFLOW;
+    }
+    Framework framework = Framework.parseByValue(frameworkStr);
+    if (framework == null) {
+      if (getConfigType() == ConfigType.CLI) {
+        throw new ParseException("Failed to parse Framework type! "
+            + "Valid values are: " + Framework.getValues());
+      } else {
+        throw new YamlParseException(YAML_PARSE_FAILED +
+            ", framework should is defined, but it has an invalid value! " +
+            "Valid values are: " + Framework.getValues());
+      }
+    }
+    return framework;
   }
 
   /**
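determineFrameworkType() keeps the CLI backward compatible: omitting --framework preserves
the old TensorFlow behaviour, while an unrecognized value fails fast. Roughly (a sketch,
not from the commit; the jar name is a placeholder):

    # No --framework given: falls back to TensorFlow, exactly as before this change.
    yarn jar hadoop-yarn-applications-submarine-<version>.jar job run --name tf-job ...

    # Invalid value: rejected with a ParseException listing the valid values
    # (or a YamlParseException when the job spec came from a YAML file).
    yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
        --name bad-job --framework mxnet ...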
@@ -108,6 +212,8 @@ public final class ParametersHolder {
   private void initGenericConfigs(YamlConfigFile yamlConfig,
       Map<String, String> yamlConfigs) {
     yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName());
+    yamlConfigs.put(CliConstants.FRAMEWORK,
+        yamlConfig.getSpec().getFramework());
 
     Configs configs = yamlConfig.getConfigs();
     yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath());
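The same selection works from a YAML job spec, since initGenericConfigs() now maps
spec.framework onto the framework option. A minimal sketch (the spec fields follow the
hunk above; the rest of a real spec file is omitted and the jar name is a placeholder):

    # For PyTorch specs, no ps role and no tensorBoard section may be present,
    # or ensureOnlyValidSectionsAreDefined() rejects the file.
    cat > pytorch-job.yaml <<'EOF'
    spec:
      name: pytorch-cifar10
      framework: pytorch
    EOF
    yarn jar hadoop-yarn-applications-submarine-<version>.jar job run -f pytorch-job.yaml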
@@ -178,13 +284,15 @@ public final class ParametersHolder {
         .collect(Collectors.toList());
   }
 
-  public static ParametersHolder createWithCmdLine(CommandLine cli) {
-    return new ParametersHolder(cli, null);
+  public static ParametersHolder createWithCmdLine(CommandLine cli,
+      Command command) throws ParseException, YarnException {
+    return new ParametersHolder(cli, null, ConfigType.CLI, command);
   }
 
   public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli,
-      YamlConfigFile yamlConfig) {
-    return new ParametersHolder(cli, yamlConfig);
+      YamlConfigFile yamlConfig, Command command) throws ParseException,
+      YarnException {
+    return new ParametersHolder(cli, yamlConfig, ConfigType.YAML, command);
   }
 
   /**
@@ -193,7 +301,7 @@ public final class ParametersHolder {
    * @param option Name of the config.
    * @return The value of the config
    */
-  String getOptionValue(String option) throws YarnException {
+  public String getOptionValue(String option) throws YarnException {
     ensureConfigIsDefinedOnce(option, true);
     if (onlyDefinedWithCliArgs.contains(option) ||
         parsedCommandLine.hasOption(option)) {
@@ -208,7 +316,7 @@ public final class ParametersHolder {
    * @param option Name of the config.
    * @return The values of the config
    */
-  List<String> getOptionValues(String option) throws YarnException {
+  public List<String> getOptionValues(String option) throws YarnException {
     ensureConfigIsDefinedOnce(option, false);
     if (onlyDefinedWithCliArgs.contains(option) ||
         parsedCommandLine.hasOption(option)) {
@@ -285,7 +393,7 @@ public final class ParametersHolder {
    * @return true, if the option is found in the CLI args or in the YAML config,
    * false otherwise.
    */
-  boolean hasOption(String option) {
+  public boolean hasOption(String option) {
     if (onlyDefinedWithCliArgs.contains(option)) {
       boolean value = parsedCommandLine.hasOption(option);
       if (LOG.isDebugEnabled()) {
@@ -312,4 +420,21 @@ public final class ParametersHolder {
         "from YAML configuration.", result, option);
     return result;
   }
+
+  public ConfigType getConfigType() {
+    return configType;
+  }
+
+  public Framework getFramework() {
+    return framework;
+  }
+
+  public void updateParameters(ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    parameters.updateParameters(this, clientContext);
+  }
+
+  public BaseParameters getParameters() {
+    return parameters;
+  }
 }
PyTorchRunJobParameters.java (new file):
@@ -0,0 +1,120 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;

import java.io.IOException;
import java.util.List;

import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.common.ClientContext;

import com.google.common.collect.Lists;

/**
 * Parameters for PyTorch job.
 */
public class PyTorchRunJobParameters extends RunJobParameters {

  private static final String CANNOT_BE_DEFINED_FOR_PYTORCH =
      "cannot be defined for PyTorch jobs!";

  @Override
  public void updateParameters(ParametersHolder parametersHolder,
      ClientContext clientContext)
      throws ParseException, IOException, YarnException {
    checkArguments(parametersHolder);

    super.updateParameters(parametersHolder, clientContext);

    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
    this.workerParameters =
        getWorkerParameters(clientContext, parametersHolder, input);
    this.distributed = determineIfDistributed(workerParameters.getReplicas());
    executePostOperations(clientContext);
  }

  private void checkArguments(ParametersHolder parametersHolder)
      throws YarnException, ParseException {
    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.N_PS));
    } else if (parametersHolder.getOptionValue(CliConstants.PS_RES) != null) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.PS_RES));
    } else if (parametersHolder
        .getOptionValue(CliConstants.PS_DOCKER_IMAGE) != null) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.PS_DOCKER_IMAGE));
    } else if (parametersHolder
        .getOptionValue(CliConstants.PS_LAUNCH_CMD) != null) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.PS_LAUNCH_CMD));
    } else if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.TENSORBOARD));
    } else if (parametersHolder
        .getOptionValue(CliConstants.TENSORBOARD_RESOURCES) != null) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.TENSORBOARD_RESOURCES));
    } else if (parametersHolder
        .getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE) != null) {
      throw new ParseException(getParamCannotBeDefinedErrorMessage(
          CliConstants.TENSORBOARD_DOCKER_IMAGE));
    }
  }

  private String getParamCannotBeDefinedErrorMessage(String cliName) {
    return String.format(
        "Parameter '%s' " + CANNOT_BE_DEFINED_FOR_PYTORCH, cliName);
  }

  @Override
  void executePostOperations(ClientContext clientContext) throws IOException {
    // Set default job dir / saved model dir, etc.
    setDefaultDirs(clientContext);
    replacePatternsInParameters(clientContext);
  }

  private void replacePatternsInParameters(ClientContext clientContext)
      throws IOException {
    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
      String afterReplace =
          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
              clientContext.getRemoteDirectoryManager());
      setWorkerLaunchCmd(afterReplace);
    }
  }

  @Override
  public List<String> getLaunchCommands() {
    return Lists.newArrayList(getWorkerLaunchCmd());
  }

  /**
   * We only support non-distributed PyTorch integration for now.
   * @param nWorkers
   * @return
   */
  private boolean determineIfDistributed(int nWorkers) {
    return false;
  }
}
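checkArguments() rejects every parameter-server and TensorBoard flag up front for PyTorch
jobs, so misconfigurations fail before anything is submitted. For example (sketch; the jar
name is a placeholder):

    # Fails with: Parameter 'num_ps' cannot be defined for PyTorch jobs!
    yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
        --name pytorch-cifar10 --framework pytorch --num_ps 2 ...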
RunJobParameters.java:
@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /**
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,7 +28,7 @@
  * limitations under the License. See accompanying LICENSE file.
  */
 
-package org.apache.hadoop.yarn.submarine.client.cli.param;
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.CaseFormat;
@@ -21,7 +37,14 @@ import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
 import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
+import org.apache.hadoop.yarn.submarine.client.cli.param.RunParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.yaml.snakeyaml.introspector.Property;
 import org.yaml.snakeyaml.introspector.PropertyUtils;
@@ -34,27 +57,15 @@ import java.util.List;
 /**
  * Parameters used to run a job
  */
-public class RunJobParameters extends RunParameters {
+public abstract class RunJobParameters extends RunParameters {
   private String input;
   private String checkpointPath;
 
-  private int numWorkers;
-  private int numPS;
-  private Resource workerResource;
-  private Resource psResource;
-  private boolean tensorboardEnabled;
-  private Resource tensorboardResource;
-  private String tensorboardDockerImage;
-  private String workerLaunchCmd;
-  private String psLaunchCmd;
   private List<Quicklink> quicklinks = new ArrayList<>();
   private List<Localization> localizations = new ArrayList<>();
 
-  private String psDockerImage = null;
-  private String workerDockerImage = null;
-
   private boolean waitJobFinish = false;
-  private boolean distributed = false;
+  protected boolean distributed = false;
 
   private boolean securityDisabled = false;
   private String keytab;
@@ -62,6 +73,9 @@ public class RunJobParameters extends RunParameters {
   private boolean distributeKeytab = false;
   private List<String> confPairs = new ArrayList<>();
 
+  RoleParameters workerParameters =
+      RoleParameters.createEmpty(TensorFlowRole.WORKER);
+
   @Override
   public void updateParameters(ParametersHolder parametersHolder,
       ClientContext clientContext)
@@ -70,34 +84,6 @@ public class RunJobParameters extends RunParameters {
     String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
     String jobDir = parametersHolder.getOptionValue(
         CliConstants.CHECKPOINT_PATH);
-    int nWorkers = 1;
-    if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
-      nWorkers = Integer.parseInt(
-          parametersHolder.getOptionValue(CliConstants.N_WORKERS));
-      // Only check null value.
-      // Training job shouldn't ignore INPUT_PATH option
-      // But if nWorkers is 0, INPUT_PATH can be ignored because
-      // user can only run Tensorboard
-      if (null == input && 0 != nWorkers) {
-        throw new ParseException("\"--" + CliConstants.INPUT_PATH +
-            "\" is absent");
-      }
-    }
-
-    int nPS = 0;
-    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
-      nPS = Integer.parseInt(
-          parametersHolder.getOptionValue(CliConstants.N_PS));
-    }
-
-    // Check #workers and #ps.
-    // When distributed training is required
-    if (nWorkers >= 2 && nPS > 0) {
-      distributed = true;
-    } else if (nWorkers <= 1 && nPS > 0) {
-      throw new ParseException("Only specified one worker but non-zero PS, "
-          + "please double check.");
-    }
 
     if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) {
       setSecurityDisabled(true);
@@ -109,46 +95,6 @@ public class RunJobParameters extends RunParameters {
         CliConstants.PRINCIPAL);
     CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal);
 
-    workerResource = null;
-    if (nWorkers > 0) {
-      String workerResourceStr = parametersHolder.getOptionValue(
-          CliConstants.WORKER_RES);
-      if (workerResourceStr == null) {
-        throw new ParseException(
-            "--" + CliConstants.WORKER_RES + " is absent.");
-      }
-      workerResource = ResourceUtils.createResourceFromString(
-          workerResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-    }
-
-    Resource psResource = null;
-    if (nPS > 0) {
-      String psResourceStr = parametersHolder.getOptionValue(
-          CliConstants.PS_RES);
-      if (psResourceStr == null) {
-        throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
-      }
-      psResource = ResourceUtils.createResourceFromString(psResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-    }
-
-    boolean tensorboard = false;
-    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
-      tensorboard = true;
-      String tensorboardResourceStr = parametersHolder.getOptionValue(
-          CliConstants.TENSORBOARD_RESOURCES);
-      if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
-        tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
-      }
-      tensorboardResource = ResourceUtils.createResourceFromString(
-          tensorboardResourceStr,
-          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
-      tensorboardDockerImage = parametersHolder.getOptionValue(
-          CliConstants.TENSORBOARD_DOCKER_IMAGE);
-      this.setTensorboardResource(tensorboardResource);
-    }
-
     if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) {
       this.waitJobFinish = true;
     }
@ -164,16 +110,6 @@ public class RunJobParameters extends RunParameters {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
psDockerImage = parametersHolder.getOptionValue(
|
|
||||||
CliConstants.PS_DOCKER_IMAGE);
|
|
||||||
workerDockerImage = parametersHolder.getOptionValue(
|
|
||||||
CliConstants.WORKER_DOCKER_IMAGE);
|
|
||||||
|
|
||||||
String workerLaunchCmd = parametersHolder.getOptionValue(
|
|
||||||
CliConstants.WORKER_LAUNCH_CMD);
|
|
||||||
String psLaunchCommand = parametersHolder.getOptionValue(
|
|
||||||
CliConstants.PS_LAUNCH_CMD);
|
|
||||||
|
|
||||||
// Localizations
|
// Localizations
|
||||||
List<String> localizationsStr = parametersHolder.getOptionValues(
|
List<String> localizationsStr = parametersHolder.getOptionValues(
|
||||||
CliConstants.LOCALIZATION);
|
CliConstants.LOCALIZATION);
|
||||||
|
@ -191,10 +127,6 @@ public class RunJobParameters extends RunParameters {
|
||||||
.getOptionValues(CliConstants.ARG_CONF);
|
.getOptionValues(CliConstants.ARG_CONF);
|
||||||
|
|
||||||
this.setInputPath(input).setCheckpointPath(jobDir)
|
this.setInputPath(input).setCheckpointPath(jobDir)
|
||||||
.setNumPS(nPS).setNumWorkers(nWorkers)
|
|
||||||
.setPSLaunchCmd(psLaunchCommand).setWorkerLaunchCmd(workerLaunchCmd)
|
|
||||||
.setPsResource(psResource)
|
|
||||||
.setTensorboardEnabled(tensorboard)
|
|
||||||
.setKeytab(kerberosKeytab)
|
.setKeytab(kerberosKeytab)
|
||||||
.setPrincipal(kerberosPrincipal)
|
.setPrincipal(kerberosPrincipal)
|
||||||
.setDistributeKeytab(distributeKerberosKeytab)
|
.setDistributeKeytab(distributeKerberosKeytab)
|
||||||
|
@ -203,6 +135,39 @@ public class RunJobParameters extends RunParameters {
|
||||||
super.updateParameters(parametersHolder, clientContext);
|
super.updateParameters(parametersHolder, clientContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
abstract void executePostOperations(ClientContext clientContext)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
|
void setDefaultDirs(ClientContext clientContext) throws IOException {
|
||||||
|
// Create directories if needed
|
||||||
|
String jobDir = getCheckpointPath();
|
||||||
|
if (jobDir == null) {
|
||||||
|
jobDir = getJobDir(clientContext);
|
||||||
|
setCheckpointPath(jobDir);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getNumWorkers() > 0) {
|
||||||
|
String savedModelDir = getSavedModelPath();
|
||||||
|
if (savedModelDir == null) {
|
||||||
|
savedModelDir = jobDir;
|
||||||
|
setSavedModelPath(savedModelDir);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getJobDir(ClientContext clientContext) throws IOException {
|
||||||
|
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
|
||||||
|
if (getNumWorkers() > 0) {
|
||||||
|
return rdm.getJobCheckpointDir(getName(), true).toString();
|
||||||
|
} else {
|
||||||
|
// when #workers == 0, it means we only launch TB. In that case,
|
||||||
|
// point job dir to root dir so all job's metrics will be shown.
|
||||||
|
return rdm.getUserRootFolder().toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract List<String> getLaunchCommands();
|
||||||
|
|
||||||
public String getInputPath() {
|
public String getInputPath() {
|
||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
@ -221,110 +186,10 @@ public class RunJobParameters extends RunParameters {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getNumWorkers() {
|
|
||||||
return numWorkers;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setNumWorkers(int numWorkers) {
|
|
||||||
this.numWorkers = numWorkers;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getNumPS() {
|
|
||||||
return numPS;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setNumPS(int numPS) {
|
|
||||||
this.numPS = numPS;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Resource getWorkerResource() {
|
|
||||||
return workerResource;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setWorkerResource(Resource workerResource) {
|
|
||||||
this.workerResource = workerResource;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Resource getPsResource() {
|
|
||||||
return psResource;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setPsResource(Resource psResource) {
|
|
||||||
this.psResource = psResource;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isTensorboardEnabled() {
|
|
||||||
return tensorboardEnabled;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setTensorboardEnabled(boolean tensorboardEnabled) {
|
|
||||||
this.tensorboardEnabled = tensorboardEnabled;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getWorkerLaunchCmd() {
|
|
||||||
return workerLaunchCmd;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setWorkerLaunchCmd(String workerLaunchCmd) {
|
|
||||||
this.workerLaunchCmd = workerLaunchCmd;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getPSLaunchCmd() {
|
|
||||||
return psLaunchCmd;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RunJobParameters setPSLaunchCmd(String psLaunchCmd) {
|
|
||||||
this.psLaunchCmd = psLaunchCmd;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isWaitJobFinish() {
|
public boolean isWaitJobFinish() {
|
||||||
return waitJobFinish;
|
return waitJobFinish;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public String getPsDockerImage() {
|
|
||||||
return psDockerImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setPsDockerImage(String psDockerImage) {
|
|
||||||
this.psDockerImage = psDockerImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getWorkerDockerImage() {
|
|
||||||
return workerDockerImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setWorkerDockerImage(String workerDockerImage) {
|
|
||||||
this.workerDockerImage = workerDockerImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isDistributed() {
|
|
||||||
return distributed;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Resource getTensorboardResource() {
|
|
||||||
return tensorboardResource;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setTensorboardResource(Resource tensorboardResource) {
|
|
||||||
this.tensorboardResource = tensorboardResource;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getTensorboardDockerImage() {
|
|
||||||
return tensorboardDockerImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setTensorboardDockerImage(String tensorboardDockerImage) {
|
|
||||||
this.tensorboardDockerImage = tensorboardDockerImage;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Quicklink> getQuicklinks() {
|
public List<Quicklink> getQuicklinks() {
|
||||||
return quicklinks;
|
return quicklinks;
|
||||||
}
|
}
|
||||||
|
@ -382,6 +247,90 @@ public class RunJobParameters extends RunParameters {
|
||||||
this.distributed = distributed;
|
this.distributed = distributed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RoleParameters getWorkerParameters(ClientContext clientContext,
|
||||||
|
ParametersHolder parametersHolder, String input)
|
||||||
|
throws ParseException, YarnException, IOException {
|
||||||
|
int nWorkers = getNumberOfWorkers(parametersHolder, input);
|
||||||
|
Resource workerResource =
|
||||||
|
determineWorkerResource(parametersHolder, nWorkers, clientContext);
|
||||||
|
String workerDockerImage =
|
||||||
|
parametersHolder.getOptionValue(CliConstants.WORKER_DOCKER_IMAGE);
|
||||||
|
String workerLaunchCmd =
|
||||||
|
parametersHolder.getOptionValue(CliConstants.WORKER_LAUNCH_CMD);
|
||||||
|
return new RoleParameters(TensorFlowRole.WORKER, nWorkers,
|
||||||
|
workerLaunchCmd, workerDockerImage, workerResource);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Resource determineWorkerResource(ParametersHolder parametersHolder,
|
||||||
|
int nWorkers, ClientContext clientContext)
|
||||||
|
throws ParseException, YarnException, IOException {
|
||||||
|
if (nWorkers > 0) {
|
||||||
|
String workerResourceStr =
|
||||||
|
parametersHolder.getOptionValue(CliConstants.WORKER_RES);
|
||||||
|
if (workerResourceStr == null) {
|
||||||
|
throw new ParseException(
|
||||||
|
"--" + CliConstants.WORKER_RES + " is absent.");
|
||||||
|
}
|
||||||
|
return ResourceUtils.createResourceFromString(workerResourceStr,
|
||||||
|
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getNumberOfWorkers(ParametersHolder parametersHolder,
|
||||||
|
String input) throws ParseException, YarnException {
|
||||||
|
int nWorkers = 1;
|
||||||
|
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
|
||||||
|
nWorkers = Integer
|
||||||
|
.parseInt(parametersHolder.getOptionValue(CliConstants.N_WORKERS));
|
||||||
|
// Only check null value.
|
||||||
|
// Training job shouldn't ignore INPUT_PATH option
|
||||||
|
// But if nWorkers is 0, INPUT_PATH can be ignored because
|
||||||
|
// user can only run Tensorboard
|
||||||
|
if (null == input && 0 != nWorkers) {
|
||||||
|
throw new ParseException(
|
||||||
|
"\"--" + CliConstants.INPUT_PATH + "\" is absent");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nWorkers;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getWorkerLaunchCmd() {
|
||||||
|
return workerParameters.getLaunchCommand();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWorkerLaunchCmd(String launchCmd) {
|
||||||
|
workerParameters.setLaunchCommand(launchCmd);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumWorkers() {
|
||||||
|
return workerParameters.getReplicas();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNumWorkers(int numWorkers) {
|
||||||
|
workerParameters.setReplicas(numWorkers);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Resource getWorkerResource() {
|
||||||
|
return workerParameters.getResource();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWorkerResource(Resource resource) {
|
||||||
|
workerParameters.setResource(resource);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getWorkerDockerImage() {
|
||||||
|
return workerParameters.getDockerImage();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setWorkerDockerImage(String image) {
|
||||||
|
workerParameters.setDockerImage(image);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isDistributed() {
|
||||||
|
return distributed;
|
||||||
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
|
public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
|
||||||
@Override
|
@Override
|
|
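Reviewer note: to see the shape of the refactor at a glance, here is a minimal, hypothetical subclass (the class name is made up; the real implementation is TensorFlowRunJobParameters in the next file). It assumes the same package as RunJobParameters, since executePostOperations is package-private:

```java
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.List;
import org.apache.hadoop.yarn.submarine.common.ClientContext;

// Hypothetical single-role subclass; illustrates the new contract only.
// Must live in the same package as RunJobParameters.
public class MinimalRunJobParameters extends RunJobParameters {
  @Override
  void executePostOperations(ClientContext clientContext) throws IOException {
    // Reuse the shared checkpoint / saved-model defaulting from the base.
    setDefaultDirs(clientContext);
  }

  @Override
  public List<String> getLaunchCommands() {
    // A framework with a single worker role only exposes one command.
    return Lists.newArrayList(getWorkerLaunchCmd());
  }
}
```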
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
+import org.apache.hadoop.yarn.submarine.common.ClientContext;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Parameters for TensorFlow job.
+ */
+public class TensorFlowRunJobParameters extends RunJobParameters {
+  private boolean tensorboardEnabled;
+  private RoleParameters psParameters =
+      RoleParameters.createEmpty(TensorFlowRole.PS);
+  private RoleParameters tensorBoardParameters =
+      RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);
+
+  @Override
+  public void updateParameters(ParametersHolder parametersHolder,
+      ClientContext clientContext)
+      throws ParseException, IOException, YarnException {
+    super.updateParameters(parametersHolder, clientContext);
+
+    String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
+    this.workerParameters =
+        getWorkerParameters(clientContext, parametersHolder, input);
+    this.psParameters = getPSParameters(clientContext, parametersHolder);
+    this.distributed = determineIfDistributed(workerParameters.getReplicas(),
+        psParameters.getReplicas());
+
+    if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
+      this.tensorboardEnabled = true;
+      this.tensorBoardParameters =
+          getTensorBoardParameters(parametersHolder, clientContext);
+    }
+    executePostOperations(clientContext);
+  }
+
+  @Override
+  void executePostOperations(ClientContext clientContext) throws IOException {
+    // Set default job dir / saved model dir, etc.
+    setDefaultDirs(clientContext);
+    replacePatternsInParameters(clientContext);
+  }
+
+  private void replacePatternsInParameters(ClientContext clientContext)
+      throws IOException {
+    if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
+      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
+          getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
+      setPSLaunchCmd(afterReplace);
+    }
+
+    if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
+      String afterReplace =
+          CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
+              clientContext.getRemoteDirectoryManager());
+      setWorkerLaunchCmd(afterReplace);
+    }
+  }
+
+  @Override
+  public List<String> getLaunchCommands() {
+    return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
+  }
+
+  private boolean determineIfDistributed(int nWorkers, int nPS)
+      throws ParseException {
+    // Check #workers and #ps.
+    // When distributed training is required
+    if (nWorkers >= 2 && nPS > 0) {
+      return true;
+    } else if (nWorkers <= 1 && nPS > 0) {
+      throw new ParseException("Only specified one worker but non-zero PS, "
+          + "please double check.");
+    }
+    return false;
+  }
+
+  private RoleParameters getPSParameters(ClientContext clientContext,
+      ParametersHolder parametersHolder)
+      throws YarnException, IOException, ParseException {
+    int nPS = getNumberOfPS(parametersHolder);
+    Resource psResource =
+        determinePSResource(parametersHolder, nPS, clientContext);
+    String psDockerImage =
+        parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
+    String psLaunchCommand =
+        parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
+    return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
+        psDockerImage, psResource);
+  }
+
+  private Resource determinePSResource(ParametersHolder parametersHolder,
+      int nPS, ClientContext clientContext)
+      throws ParseException, YarnException, IOException {
+    if (nPS > 0) {
+      String psResourceStr =
+          parametersHolder.getOptionValue(CliConstants.PS_RES);
+      if (psResourceStr == null) {
+        throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
+      }
+      return ResourceUtils.createResourceFromString(psResourceStr,
+          clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    }
+    return null;
+  }
+
+  private int getNumberOfPS(ParametersHolder parametersHolder)
+      throws YarnException {
+    int nPS = 0;
+    if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
+      nPS =
+          Integer.parseInt(parametersHolder.getOptionValue(CliConstants.N_PS));
+    }
+    return nPS;
+  }
+
+  private RoleParameters getTensorBoardParameters(
+      ParametersHolder parametersHolder, ClientContext clientContext)
+      throws YarnException, IOException {
+    String tensorboardResourceStr =
+        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
+    if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
+      tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
+    }
+    Resource tensorboardResource =
+        ResourceUtils.createResourceFromString(tensorboardResourceStr,
+            clientContext.getOrCreateYarnClient().getResourceTypeInfo());
+    String tensorboardDockerImage =
+        parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
+    return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
+        tensorboardDockerImage, tensorboardResource);
+  }
+
+  public int getNumPS() {
+    return psParameters.getReplicas();
+  }
+
+  public void setNumPS(int numPS) {
+    psParameters.setReplicas(numPS);
+  }
+
+  public Resource getPsResource() {
+    return psParameters.getResource();
+  }
+
+  public void setPsResource(Resource resource) {
+    psParameters.setResource(resource);
+  }
+
+  public String getPsDockerImage() {
+    return psParameters.getDockerImage();
+  }
+
+  public void setPsDockerImage(String image) {
+    psParameters.setDockerImage(image);
+  }
+
+  public String getPSLaunchCmd() {
+    return psParameters.getLaunchCommand();
+  }
+
+  public void setPSLaunchCmd(String launchCmd) {
+    psParameters.setLaunchCommand(launchCmd);
+  }
+
+  public boolean isTensorboardEnabled() {
+    return tensorboardEnabled;
+  }
+
+  public Resource getTensorboardResource() {
+    return tensorBoardParameters.getResource();
+  }
+
+  public void setTensorboardResource(Resource resource) {
+    tensorBoardParameters.setResource(resource);
+  }
+
+  public String getTensorboardDockerImage() {
+    return tensorBoardParameters.getDockerImage();
+  }
+
+  public void setTensorboardDockerImage(String image) {
+    tensorBoardParameters.setDockerImage(image);
+  }
+}
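The distributed-mode decision above is easy to restate in isolation. A standalone sketch (class name made up) that mirrors determineIfDistributed, so the three cases can be exercised directly:

```java
import org.apache.commons.cli.ParseException;

// Standalone restatement of determineIfDistributed(), for illustration only:
// >= 2 workers plus at least one PS means distributed training; a PS with at
// most one worker is rejected as a likely misconfiguration.
public final class DistributedRuleCheck {
  static boolean isDistributed(int nWorkers, int nPS) throws ParseException {
    if (nWorkers >= 2 && nPS > 0) {
      return true;
    } else if (nWorkers <= 1 && nPS > 0) {
      throw new ParseException("Only specified one worker but non-zero PS, "
          + "please double check.");
    }
    return false;
  }

  public static void main(String[] args) throws ParseException {
    System.out.println(isDistributed(2, 1)); // true
    System.out.println(isDistributed(1, 0)); // false: single-node training
    // isDistributed(1, 1) would throw ParseException
  }
}
```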
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes that hold run job parameters for
+ * TensorFlow / PyTorch jobs.
+ */
+package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
@@ -22,6 +22,7 @@ package org.apache.hadoop.yarn.submarine.client.cli.param.yaml;
 public class Spec {
   private String name;
   private String jobType;
+  private String framework;
 
   public String getJobType() {
     return jobType;
@@ -38,4 +39,12 @@ public class Spec {
   public void setName(String name) {
     this.name = name;
   }
+
+  public String getFramework() {
+    return framework;
+  }
+
+  public void setFramework(String framework) {
+    this.framework = framework;
+  }
 }
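The new field presumably picks up a `framework` key from the YAML spec section via the existing property mapping. A quick, illustrative use of the extended Spec (all values below are made up; assumes the same package as Spec):

```java
public class SpecDemo {
  public static void main(String[] args) {
    // Illustrative only: all values are made up.
    Spec spec = new Spec();
    spec.setName("pytorch-job-001");
    spec.setJobType("worker");      // hypothetical jobType value
    spec.setFramework("pytorch");   // matches Framework.PYTORCH's name
    System.out.println(spec.getFramework()); // prints: pytorch
  }
}
```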
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Represents the type of Machine learning framework to work with.
+ */
+public enum Framework {
+  TENSORFLOW(Constants.TENSORFLOW_NAME), PYTORCH(Constants.PYTORCH_NAME);
+
+  private String value;
+
+  Framework(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static Framework parseByValue(String value) {
+    for (Framework fw : Framework.values()) {
+      if (fw.value.equalsIgnoreCase(value)) {
+        return fw;
+      }
+    }
+    return null;
+  }
+
+  public static String getValues() {
+    List<String> values = Lists.newArrayList(Framework.values()).stream()
+        .map(fw -> fw.value).collect(Collectors.toList());
+    return String.join(",", values);
+  }
+
+  private static class Constants {
+    static final String TENSORFLOW_NAME = "tensorflow";
+    static final String PYTORCH_NAME = "pytorch";
+  }
+}
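A short usage sketch for the enum (behavior follows directly from the code above: matching is case-insensitive, unknown names yield null; assumes Framework is on the classpath):

```java
public class FrameworkDemo {
  public static void main(String[] args) {
    // Case-insensitive lookup; null for unknown names.
    System.out.println(Framework.parseByValue("TensorFlow")); // TENSORFLOW
    System.out.println(Framework.parseByValue("mxnet"));      // null
    // Joined lowercase names, handy for --framework help text.
    System.out.println(Framework.getValues()); // tensorflow,pytorch
  }
}
```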
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
+
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+
+/**
+ * This class encapsulates data related to a particular Role.
+ * Some examples: TF Worker process, TF PS process or a PyTorch worker process.
+ */
+public class RoleParameters {
+  private final Role role;
+  private int replicas;
+  private String launchCommand;
+  private String dockerImage;
+  private Resource resource;
+
+  public RoleParameters(Role role, int replicas,
+      String launchCommand, String dockerImage, Resource resource) {
+    this.role = role;
+    this.replicas = replicas;
+    this.launchCommand = launchCommand;
+    this.dockerImage = dockerImage;
+    this.resource = resource;
+  }
+
+  public static RoleParameters createEmpty(Role role) {
+    return new RoleParameters(role, 0, null, null, null);
+  }
+
+  public Role getRole() {
+    return role;
+  }
+
+  public int getReplicas() {
+    return replicas;
+  }
+
+  public String getLaunchCommand() {
+    return launchCommand;
+  }
+
+  public void setLaunchCommand(String launchCommand) {
+    this.launchCommand = launchCommand;
+  }
+
+  public String getDockerImage() {
+    return dockerImage;
+  }
+
+  public void setDockerImage(String dockerImage) {
+    this.dockerImage = dockerImage;
+  }
+
+  public Resource getResource() {
+    return resource;
+  }
+
+  public void setResource(Resource resource) {
+    this.resource = resource;
+  }
+
+  public void setReplicas(int replicas) {
+    this.replicas = replicas;
+  }
+}
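To make the encapsulation concrete, a hedged construction example (image name, launch command, and sizes are made up; Resource.newInstance is the standard YARN factory; assumes the same package as RoleParameters):

```java
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;

public class RoleParametersDemo {
  public static void main(String[] args) {
    // Two TF workers at 2048 MB / 2 vcores each (GPU resource omitted here).
    Resource workerRes = Resource.newInstance(2048, 2);
    RoleParameters workers = new RoleParameters(TensorFlowRole.WORKER, 2,
        "python train.py", "example-tf-image:latest", workerRes);
    System.out.println(workers.getRole().getComponentName() // worker
        + " x " + workers.getReplicas());                   // x 2
  }
}
```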
@@ -1,3 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 /**
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -12,7 +28,7 @@
  * limitations under the License. See accompanying LICENSE file.
  */
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.commons.cli.CommandLine;
@@ -23,9 +39,13 @@ import org.apache.commons.cli.ParseException;
 import org.apache.commons.io.FileUtils;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.submarine.client.cli.AbstractCli;
+import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
+import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.Command;
 import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
@@ -44,17 +64,25 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+/**
+ * The purpose of this class is to handle / parse CLI arguments related to
+ * the run job Submarine command.
+ */
 public class RunJobCli extends AbstractCli {
   private static final Logger LOG =
       LoggerFactory.getLogger(RunJobCli.class);
-  private static final String YAML_PARSE_FAILED = "Failed to parse " +
+  private static final String CAN_BE_USED_WITH_TF_PYTORCH =
+      "Can be used with TensorFlow or PyTorch frameworks.";
+  private static final String CAN_BE_USED_WITH_TF_ONLY =
+      "Can only be used with TensorFlow framework.";
+  public static final String YAML_PARSE_FAILED = "Failed to parse " +
       "YAML config";
 
-  private Options options;
-  private RunJobParameters parameters = new RunJobParameters();
-
+  private Options options;
   private JobSubmitter jobSubmitter;
   private JobMonitor jobMonitor;
+  private ParametersHolder parametersHolder;
 
   public RunJobCli(ClientContext cliContext) {
     this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
@@ -62,7 +90,7 @@ public class RunJobCli extends AbstractCli {
   }
 
   @VisibleForTesting
-  RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
+  public RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
       JobMonitor jobMonitor) {
     super(cliContext);
     this.options = generateOptions();
@@ -78,6 +106,10 @@ public class RunJobCli extends AbstractCli {
     Options options = new Options();
     options.addOption(CliConstants.YAML_CONFIG, true,
         "Config file (in YAML format)");
+    options.addOption(CliConstants.FRAMEWORK, true,
+        String.format("Framework to use. Valid values are: %s! " +
+            "The default framework is Tensorflow.",
+            Framework.getValues()));
     options.addOption(CliConstants.NAME, true, "Name of the job");
     options.addOption(CliConstants.INPUT_PATH, true,
         "Input of the job, could be local or other FS directory");
@@ -88,48 +120,22 @@ public class RunJobCli extends AbstractCli {
     options.addOption(CliConstants.SAVED_MODEL_PATH, true,
         "Model exported path (savedmodel) of the job, which is needed when "
             + "exported model is not placed under ${checkpoint_path}"
-            + "could be local or other FS directory. This will be used to serve.");
-    options.addOption(CliConstants.N_WORKERS, true,
-        "Number of worker tasks of the job, by default it's 1");
-    options.addOption(CliConstants.N_PS, true,
-        "Number of PS tasks of the job, by default it's 0");
-    options.addOption(CliConstants.WORKER_RES, true,
-        "Resource of each worker, for example "
-            + "memory-mb=2048,vcores=2,yarn.io/gpu=2");
-    options.addOption(CliConstants.PS_RES, true,
-        "Resource of each PS, for example "
-            + "memory-mb=2048,vcores=2,yarn.io/gpu=2");
+            + "could be local or other FS directory. " +
+            "This will be used to serve.");
     options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag");
     options.addOption(CliConstants.QUEUE, true,
         "Name of queue to run the job, by default it uses default queue");
-    options.addOption(CliConstants.TENSORBOARD, false,
-        "Should we run TensorBoard"
-            + " for this job? By default it's disabled");
-    options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
-        "Specify resources of Tensorboard, by default it is "
-            + CliConstants.TENSORBOARD_DEFAULT_RESOURCES);
-    options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
-        "Specify Tensorboard docker image. when this is not "
-            + "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
-            + " as default.");
-    options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
-        "Commandline of worker, arguments will be "
-            + "directly used to launch the worker");
-    options.addOption(CliConstants.PS_LAUNCH_CMD, true,
-        "Commandline of worker, arguments will be "
-            + "directly used to launch the PS");
+    addWorkerOptions(options);
+    addPSOptions(options);
+    addTensorboardOptions(options);
     options.addOption(CliConstants.ENV, true,
         "Common environment variable of worker/ps");
     options.addOption(CliConstants.VERBOSE, false,
         "Print verbose log for troubleshooting");
     options.addOption(CliConstants.WAIT_JOB_FINISH, false,
         "Specified when user want to wait the job finish");
-    options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
-        "Specify docker image for PS, when this is not specified, PS uses --"
-            + CliConstants.DOCKER_IMAGE + " as default.");
-    options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
-        "Specify docker image for WORKER, when this is not specified, WORKER "
-            + "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
     options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
         + "web UI shows link to given role instance and port. When "
         + "--tensorboard is specified, quicklink to tensorboard instance will "
@@ -172,63 +178,97 @@ public class RunJobCli extends AbstractCli {
     return options;
   }
 
-  private void replacePatternsInParameters() throws IOException {
-    if (parameters.getPSLaunchCmd() != null && !parameters.getPSLaunchCmd()
-        .isEmpty()) {
-      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
-          parameters.getPSLaunchCmd(), parameters,
-          clientContext.getRemoteDirectoryManager());
-      parameters.setPSLaunchCmd(afterReplace);
-    }
-
-    if (parameters.getWorkerLaunchCmd() != null && !parameters
-        .getWorkerLaunchCmd().isEmpty()) {
-      String afterReplace = CliUtils.replacePatternsInLaunchCommand(
-          parameters.getWorkerLaunchCmd(), parameters,
-          clientContext.getRemoteDirectoryManager());
-      parameters.setWorkerLaunchCmd(afterReplace);
-    }
+  private void addWorkerOptions(Options options) {
+    options.addOption(CliConstants.N_WORKERS, true,
+        "Number of worker tasks of the job, by default it's 1." +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
+        "Specify docker image for WORKER, when this is not specified, WORKER "
+            + "uses --" + CliConstants.DOCKER_IMAGE + " as default." +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
+        "Commandline of worker, arguments will be "
+            + "directly used to launch the worker" +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+    options.addOption(CliConstants.WORKER_RES, true,
+        "Resource of each worker, for example "
+            + "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
+            CAN_BE_USED_WITH_TF_PYTORCH);
+  }
+
+  private void addPSOptions(Options options) {
+    options.addOption(CliConstants.N_PS, true,
+        "Number of PS tasks of the job, by default it's 0. " +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
+        "Specify docker image for PS, when this is not specified, PS uses --"
+            + CliConstants.DOCKER_IMAGE + " as default." +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_LAUNCH_CMD, true,
+        "Commandline of PS, arguments will be "
+            + "directly used to launch the PS" +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.PS_RES, true,
+        "Resource of each PS, for example "
+            + "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
+            CAN_BE_USED_WITH_TF_ONLY);
+  }
+
+  private void addTensorboardOptions(Options options) {
+    options.addOption(CliConstants.TENSORBOARD, false,
+        "Should we run TensorBoard"
+            + " for this job? By default it's disabled." +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
+        "Specify resources of Tensorboard, by default it is "
+            + CliConstants.TENSORBOARD_DEFAULT_RESOURCES + "." +
+            CAN_BE_USED_WITH_TF_ONLY);
+    options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
+        "Specify Tensorboard docker image. when this is not "
+            + "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+            + " as default." +
+            CAN_BE_USED_WITH_TF_ONLY);
   }
 
   private void parseCommandLineAndGetRunJobParameters(String[] args)
       throws ParseException, IOException, YarnException {
     try {
-      // Do parsing
       GnuParser parser = new GnuParser();
       CommandLine cli = parser.parse(options, args);
-      ParametersHolder parametersHolder = createParametersHolder(cli);
-      parameters.updateParameters(parametersHolder, clientContext);
+      parametersHolder = createParametersHolder(cli);
+      parametersHolder.updateParameters(clientContext);
     } catch (ParseException e) {
       LOG.error("Exception in parse: {}", e.getMessage());
       printUsages();
       throw e;
     }
-
-    // Set default job dir / saved model dir, etc.
-    setDefaultDirs();
-
-    // replace patterns
-    replacePatternsInParameters();
   }
 
-  private ParametersHolder createParametersHolder(CommandLine cli) {
+  private ParametersHolder createParametersHolder(CommandLine cli)
+      throws ParseException, YarnException {
     String yamlConfigFile =
         cli.getOptionValue(CliConstants.YAML_CONFIG);
     if (yamlConfigFile != null) {
       YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile);
-      if (yamlConfig == null) {
-        throw new YamlParseException(String.format(
-            YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
-      } else if (yamlConfig.getConfigs() == null) {
-        throw new YamlParseException(String.format(YAML_PARSE_FAILED +
-            ", config section should be defined, but it cannot be found in " +
-            "YAML file '%s'!", yamlConfigFile));
-      }
+      checkYamlConfig(yamlConfigFile, yamlConfig);
       LOG.info("Using YAML configuration!");
-      return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig);
+      return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig,
+          Command.RUN_JOB);
     } else {
       LOG.info("Using CLI configuration!");
-      return ParametersHolder.createWithCmdLine(cli);
+      return ParametersHolder.createWithCmdLine(cli, Command.RUN_JOB);
+    }
+  }
+
+  private void checkYamlConfig(String yamlConfigFile,
+      YamlConfigFile yamlConfig) {
+    if (yamlConfig == null) {
+      throw new YamlParseException(String.format(
+          YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
+    } else if (yamlConfig.getConfigs() == null) {
+      throw new YamlParseException(String.format(YAML_PARSE_FAILED +
+          ", config section should be defined, but it cannot be found in " +
+          "YAML file '%s'!", yamlConfigFile));
     }
   }
@@ -256,34 +296,9 @@ public class RunJobCli extends AbstractCli {
           e);
     }
 
-  private void setDefaultDirs() throws IOException {
-    // Create directories if needed
-    String jobDir = parameters.getCheckpointPath();
-    if (null == jobDir) {
-      if (parameters.getNumWorkers() > 0) {
-        jobDir = clientContext.getRemoteDirectoryManager().getJobCheckpointDir(
-            parameters.getName(), true).toString();
-      } else {
-        // when #workers == 0, it means we only launch TB. In that case,
-        // point job dir to root dir so all job's metrics will be shown.
-        jobDir = clientContext.getRemoteDirectoryManager().getUserRootFolder()
-            .toString();
-      }
-      parameters.setCheckpointPath(jobDir);
-    }
-
-    if (parameters.getNumWorkers() > 0) {
-      // Only do this when #worker > 0
-      String savedModelDir = parameters.getSavedModelPath();
-      if (null == savedModelDir) {
-        savedModelDir = jobDir;
-        parameters.setSavedModelPath(savedModelDir);
-      }
-    }
-  }
-
-  private void storeJobInformation(String jobName, ApplicationId applicationId,
-      String[] args) throws IOException {
+  private void storeJobInformation(RunJobParameters parameters,
+      ApplicationId applicationId, String[] args) throws IOException {
+    String jobName = parameters.getName();
     Map<String, String> jobInfo = new HashMap<>();
     jobInfo.put(StorageKeyConstants.JOB_NAME, jobName);
     jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString());
@@ -316,8 +331,10 @@ public class RunJobCli extends AbstractCli {
     }
 
     parseCommandLineAndGetRunJobParameters(args);
-    ApplicationId applicationId = this.jobSubmitter.submitJob(parameters);
-    storeJobInformation(parameters.getName(), applicationId, args);
+    ApplicationId applicationId = jobSubmitter.submitJob(parametersHolder);
+    RunJobParameters parameters =
+        (RunJobParameters) parametersHolder.getParameters();
+    storeJobInformation(parameters, applicationId, args);
     if (parameters.isWaitJobFinish()) {
       this.jobMonitor.waitTrainingFinal(parameters.getName());
     }
@@ -332,6 +349,6 @@ public class RunJobCli extends AbstractCli {
 
   @VisibleForTesting
   public RunJobParameters getRunJobParameters() {
-    return parameters;
+    return (RunJobParameters) parametersHolder.getParameters();
   }
 }
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package contains classes that are related to the run job command.
+ */
+package org.apache.hadoop.yarn.submarine.client.cli.runjob;
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. See accompanying LICENSE file.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+/**
+ * Enum to represent a PyTorch Role.
+ */
+public enum PyTorchRole implements Role {
+  PRIMARY_WORKER("master"),
+  WORKER("worker");
+
+  private String compName;
+
+  PyTorchRole(String compName) {
+    this.compName = compName;
+  }
+
+  public String getComponentName() {
+    return compName;
+  }
+
+  @Override
+  public String getName() {
+    return name();
+  }
+}
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+/**
+ * Interface for a Role.
+ */
+public interface Role {
+  String getComponentName();
+  String getName();
+}
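Since both framework enums now implement Role, callers can treat them uniformly. A small illustration (assumes the three types above are importable; PyTorchRole is introduced earlier in this commit):

```java
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;

public class RoleDemo {
  public static void main(String[] args) {
    // Framework-agnostic handling of roles via the shared interface.
    for (Role role : new Role[] {TensorFlowRole.WORKER, PyTorchRole.WORKER}) {
      System.out.println(role.getName() + " -> " + role.getComponentName());
    }
    // prints: WORKER -> worker (for both enums)
  }
}
```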
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.submarine.common.api;
+
+import com.google.common.collect.Lists;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Represents the type of Runtime.
+ */
+public enum Runtime {
+  TONY(Constants.TONY), YARN_SERVICE(Constants.YARN_SERVICE);
+
+  private String value;
+
+  Runtime(String value) {
+    this.value = value;
+  }
+
+  public String getValue() {
+    return value;
+  }
+
+  public static Runtime parseByValue(String value) {
+    for (Runtime rt : Runtime.values()) {
+      if (rt.value.equalsIgnoreCase(value)) {
+        return rt;
+      }
+    }
+    return null;
+  }
+
+  public static String getValues() {
+    List<String> values = Lists.newArrayList(Runtime.values()).stream()
+        .map(rt -> rt.value).collect(Collectors.toList());
+    return String.join(",", values);
+  }
+
+  public static class Constants {
+    public static final String TONY = "tony";
+    public static final String YARN_SERVICE = "yarnservice";
+  }
+}
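Usage mirrors the Framework enum; one caveat worth noting is that this Runtime shadows java.lang.Runtime wherever it is imported:

```java
import org.apache.hadoop.yarn.submarine.common.api.Runtime;

public class RuntimeDemo {
  public static void main(String[] args) {
    // Shadows java.lang.Runtime in this file; fully qualify if both are needed.
    System.out.println(Runtime.parseByValue("YarnService")); // YARN_SERVICE
    System.out.println(Runtime.parseByValue("bogus"));       // null
    System.out.println(Runtime.getValues());                 // tony,yarnservice
  }
}
```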
@@ -14,7 +14,10 @@
 
 package org.apache.hadoop.yarn.submarine.common.api;
 
-public enum TaskType {
+/**
+ * Enum to represent a TensorFlow Role.
+ */
+public enum TensorFlowRole implements Role {
   PRIMARY_WORKER("master"),
   WORKER("worker"),
   PS("ps"),
@@ -22,11 +25,17 @@ public enum TaskType {
 
   private String compName;
 
-  TaskType(String compName) {
+  TensorFlowRole(String compName) {
     this.compName = compName;
   }
 
+  @Override
   public String getComponentName() {
     return compName;
   }
+
+  @Override
+  public String getName() {
+    return name();
+  }
 }
@@ -16,21 +16,21 @@ package org.apache.hadoop.yarn.submarine.runtimes.common;

 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;

 import java.io.IOException;

 /**
- * Submit job to cluster master
+ * Submit job to cluster master.
  */
 public interface JobSubmitter {
   /**
-   * Submit job to cluster
+   * Submit a job to cluster.
    * @param parameters run job parameters
-   * @return applicatioId when successfully submitted
+   * @return applicationId when successfully submitted
    * @throws YarnException for issues while contacting YARN daemons
    * @throws IOException for other issues.
    */
-  ApplicationId submitJob(RunJobParameters parameters)
+  ApplicationId submitJob(ParametersHolder parameters)
       throws IOException, YarnException;
 }
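To make the contract change concrete, a minimal sketch of an implementation under the new `ParametersHolder` signature (purely illustrative; the real YARN-service and TonY submitters in this codebase do much more):

```java
import java.io.IOException;

import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;

/** Illustrative stub; fabricates an id instead of contacting YARN. */
public class NoOpJobSubmitter implements JobSubmitter {
  @Override
  public ApplicationId submitJob(ParametersHolder parameters)
      throws IOException, YarnException {
    // A real submitter would translate 'parameters' into a service spec or
    // TonY job and submit it to the cluster, as the tests below mock out.
    return ApplicationId.newInstance(System.currentTimeMillis(), 1);
  }
}
```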
@@ -40,6 +40,10 @@ More details, please refer to

 ```$xslt
 usage: job run

+ -framework <arg>             Framework to use.
+                              Valid values are: tensorflow, pytorch.
+                              The default framework is Tensorflow.
  -checkpoint_path <arg>       Training output directory of the job, could
                               be local or other FS directory. This
                               typically includes checkpoint files and
@@ -130,6 +134,7 @@ For submarine internal configuration, please create a `submarine.xml` which shou
 #### Commandline
 ```
 yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \
+ --framework tensorflow \
 --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
 --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \
 --docker_image <your-docker-image> \
@@ -163,6 +168,7 @@ See below screenshot:
 ```
 yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
 --name tf-job-001 --docker_image <your-docker-image> \
+ --framework tensorflow \
 --input_path hdfs://default/dataset/cifar-10-data \
 --checkpoint_path hdfs://default/tmp/cifar-10-jobdir \
 --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
@@ -208,6 +214,7 @@ After that, you can run ```tensorboard --logdir=<checkpoint-path>``` to view Ten
 yarn app -destroy tensorboard-service; \
 yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \
 job run --name tensorboard-service --verbose --docker_image <your-docker-image> \
+ --framework tensorflow \
 --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
 --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \
 --num_workers 0 --tensorboard

@@ -1,226 +0,0 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli;

import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TestRunJobCliParsing {

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @Test
  public void testPrintHelp() {
    MockClientContext mockClientContext = new MockClientContext();
    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
    JobMonitor mockJobMonitor = mock(JobMonitor.class);
    RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
        mockJobMonitor);
    runJobCli.printUsages();
  }

  static MockClientContext getMockClientContext()
      throws IOException, YarnException {
    MockClientContext mockClientContext = new MockClientContext();
    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
    when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
        ApplicationId.newInstance(1234L, 1));
    JobMonitor mockJobMonitor = mock(JobMonitor.class);
    SubmarineStorage storage = mock(SubmarineStorage.class);
    RuntimeFactory rtFactory = mock(RuntimeFactory.class);

    when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
    when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
    when(rtFactory.getSubmarineStorage()).thenReturn(storage);

    mockClientContext.setRuntimeFactory(rtFactory);
    return mockClientContext;
  }

  @Test
  public void testBasicRunJobForDistributedTraining() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    Assert.assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
            "--principal", "user/_HOST@domain.com", "--distribute_keytab",
            "--verbose" });

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();

    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
    assertEquals(jobRunParameters.getNumPS(), 2);
    assertEquals(jobRunParameters.getPSLaunchCmd(), "python run-ps.py");
    assertEquals(Resources.createResource(4096, 4),
        jobRunParameters.getPsResource());
    assertEquals(jobRunParameters.getWorkerLaunchCmd(),
        "python run-job.py");
    assertEquals(Resources.createResource(2048, 2),
        jobRunParameters.getWorkerResource());
    assertEquals(jobRunParameters.getDockerImageName(),
        "tf-docker:1.1.0");
    assertEquals(jobRunParameters.getKeytab(),
        "/keytab/path");
    assertEquals(jobRunParameters.getPrincipal(),
        "user/_HOST@domain.com");
    Assert.assertTrue(jobRunParameters.isDistributeKeytab());
    Assert.assertTrue(SubmarineLogs.isVerbose());
  }

  @Test
  public void testBasicRunJobForSingleNodeTraining() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
            "true", "--verbose", "--wait_job_finish" });

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();

    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
    assertEquals(jobRunParameters.getNumWorkers(), 1);
    assertEquals(jobRunParameters.getWorkerLaunchCmd(),
        "python run-job.py");
    assertEquals(Resources.createResource(4096, 2),
        jobRunParameters.getWorkerResource());
    Assert.assertTrue(SubmarineLogs.isVerbose());
    Assert.assertTrue(jobRunParameters.isWaitJobFinish());
  }

  @Test
  public void testNoInputPathOptionSpecified() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\" is absent";
    String actualMessage = "";
    try {
      runJobCli.run(
          new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
              "--checkpoint_path", "hdfs://output",
              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
              "true", "--verbose", "--wait_job_finish"});
    } catch (ParseException e) {
      actualMessage = e.getMessage();
      e.printStackTrace();
    }
    assertEquals(expectedErrorMessage, actualMessage);
  }

  /**
   * when only run tensorboard, input_path is not needed
   * */
  @Test
  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    boolean success = true;
    try {
      runJobCli.run(
          new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
              "--num_workers", "0", "--tensorboard", "--verbose",
              "--tensorboard_resources", "memory=2G,vcores=2",
              "--tensorboard_docker_image", "tb_docker_image:001"});
    } catch (ParseException e) {
      success = false;
    }
    Assert.assertTrue(success);
  }

  @Test
  public void testJobWithoutName() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    String expectedErrorMessage =
        "--" + CliConstants.NAME + " is absent";
    String actualMessage = "";
    try {
      runJobCli.run(
          new String[]{"--docker_image", "tf-docker:1.1.0",
              "--num_workers", "0", "--tensorboard", "--verbose",
              "--tensorboard_resources", "memory=2G,vcores=2",
              "--tensorboard_docker_image", "tb_docker_image:001"});
    } catch (ParseException e) {
      actualMessage = e.getMessage();
      e.printStackTrace();
    }
    assertEquals(expectedErrorMessage, actualMessage);
  }

  @Test
  public void testLaunchCommandPatternReplace() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path",
            "hdfs://output",
            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
            "python run-job.py --input=%input_path% " +
                "--model_dir=%checkpoint_path% " +
                "--export_dir=%saved_model_path%/savedmodel",
            "--worker_resources", "memory=2048,vcores=2", "--ps_resources",
            "memory=4096,vcores=4", "--tensorboard", "true", "--ps_launch_cmd",
            "python run-ps.py --input=%input_path% " +
                "--model_dir=%checkpoint_path%/model",
            "--verbose" });

    assertEquals(
        "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
            + "--export_dir=hdfs://output/savedmodel",
        runJobCli.getRunJobParameters().getWorkerLaunchCmd());
    assertEquals(
        "python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model",
        runJobCli.getRunJobParameters().getPSLaunchCmd());
  }
}

@@ -17,7 +17,7 @@
 package org.apache.hadoop.yarn.submarine.client.cli;

 import org.apache.commons.io.FileUtils;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
 import org.yaml.snakeyaml.Yaml;
 import org.yaml.snakeyaml.constructor.Constructor;
@@ -33,13 +33,13 @@ public final class YamlConfigTestUtils {

   private YamlConfigTestUtils() {}

-  static void deleteFile(File file) {
+  public static void deleteFile(File file) {
     if (file != null) {
       file.delete();
     }
   }

-  static YamlConfigFile readYamlConfigFile(String filename) {
+  public static YamlConfigFile readYamlConfigFile(String filename) {
     Constructor constructor = new Constructor(YamlConfigFile.class);
     constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
     Yaml yaml = new Yaml(constructor);
@@ -49,7 +49,8 @@ public final class YamlConfigTestUtils {
     return yaml.loadAs(inputStream, YamlConfigFile.class);
   }

-  static File createTempFileWithContents(String filename) throws IOException {
+  public static File createTempFileWithContents(String filename)
+      throws IOException {
     InputStream inputStream = YamlConfigTestUtils.class
         .getClassLoader()
         .getResourceAsStream(filename);
@@ -58,7 +59,7 @@ public final class YamlConfigTestUtils {
     return targetFile;
   }

-  static File createEmptyTempFile() throws IOException {
+  public static File createEmptyTempFile() throws IOException {
     return File.createTempFile("test", ".yaml");
   }
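For illustration, how a test outside this package can use the now-public helpers (the resource name below is hypothetical):

```java
// Hypothetical test body exercising the now-public helpers.
File config = YamlConfigTestUtils.createTempFileWithContents(
    "runjob-common-yaml/some-config.yaml"); // resource name is an assumption
try {
  YamlConfigFile parsed = YamlConfigTestUtils.readYamlConfigFile(
      "runjob-common-yaml/some-config.yaml");
  // assertions against 'parsed' go here
} finally {
  YamlConfigTestUtils.deleteFile(config); // null-safe cleanup
}
```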
@@ -0,0 +1,129 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.runjob;

import org.apache.commons.cli.MissingArgumentException;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.io.IOException;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * This class contains some test methods to test common functionality
 * (including TF / PyTorch) of the run job Submarine command.
 */
public class TestRunJobCliParsingCommon {

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @Rule
  public ExpectedException expectedException = ExpectedException.none();

  public static MockClientContext getMockClientContext()
      throws IOException, YarnException {
    MockClientContext mockClientContext = new MockClientContext();
    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
    when(mockJobSubmitter.submitJob(any(ParametersHolder.class)))
        .thenReturn(ApplicationId.newInstance(1235L, 1));

    JobMonitor mockJobMonitor = mock(JobMonitor.class);
    SubmarineStorage storage = mock(SubmarineStorage.class);
    RuntimeFactory rtFactory = mock(RuntimeFactory.class);

    when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
    when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
    when(rtFactory.getSubmarineStorage()).thenReturn(storage);

    mockClientContext.setRuntimeFactory(rtFactory);
    return mockClientContext;
  }

  @Test
  public void testAbsentFrameworkFallsBackToTensorFlow() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path",
            "hdfs://output",
            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
            "true", "--verbose", "--wait_job_finish"});
    RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
    assertTrue("Default Framework should be TensorFlow!",
        runJobParameters instanceof TensorFlowRunJobParameters);
  }

  @Test
  public void testEmptyFrameworkOption() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(MissingArgumentException.class);
    expectedException.expectMessage("Missing argument for option: framework");

    runJobCli.run(
        new String[]{"--framework", "--name", "my-job",
            "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path",
            "hdfs://output",
            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
            "true", "--verbose", "--wait_job_finish"});
  }

  @Test
  public void testInvalidFrameworkOption() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("Failed to parse Framework type");

    runJobCli.run(
        new String[]{"--framework", "bla", "--name", "my-job",
            "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path",
            "hdfs://output",
            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
            "true", "--verbose", "--wait_job_finish"});
  }
}

@@ -0,0 +1,252 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.runjob;

import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/**
 * This class contains some test methods to test common YAML parsing
 * functionality (including TF / PyTorch) of the run job Submarine command.
 */
public class TestRunJobCliParsingCommonYaml {
  private static final String DIR_NAME = "runjob-common-yaml";
  private static final String TF_DIR = "runjob-pytorch-yaml";
  private File yamlConfig;

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @After
  public void after() {
    YamlConfigTestUtils.deleteFile(yamlConfig);
  }

  @BeforeClass
  public static void configureResourceTypes() {
    List<ResourceTypeInfo> resTypes = new ArrayList<>(
        ResourceUtils.getResourcesTypeInfo());
    resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
    ResourceUtils.reinitializeResources(resTypes);
  }

  @Rule
  public ExpectedException exception = ExpectedException.none();

  @Test
  public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        TF_DIR + "/valid-config.yaml");
    String[] args = new String[] {"--name", "my-job",
        "--docker_image", "tf-docker:1.1.0",
        "-f", yamlConfig.getAbsolutePath() };

    exception.expect(YarnException.class);
    exception.expectMessage("defined both with YAML config and with " +
        "CLI argument");

    runJobCli.run(args);
  }

  @Test
  public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
      throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        TF_DIR + "/valid-config.yaml");
    String[] args = new String[] {"--name", "my-job",
        "--quicklink", "AAA=http://master-0:8321",
        "--quicklink", "BBB=http://worker-0:1234",
        "-f", yamlConfig.getAbsolutePath()};

    exception.expect(YarnException.class);
    exception.expectMessage("defined both with YAML config and with " +
        "CLI argument");

    runJobCli.run(args);
  }

  @Test
  public void testFalseValuesForBooleanFields() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/test-false-values.yaml");
    runJobCli.run(
        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();

    assertTrue(RunJobParameters.class + " must be an instance of " +
        TensorFlowRunJobParameters.class,
        jobRunParameters instanceof TensorFlowRunJobParameters);
    TensorFlowRunJobParameters tensorFlowParams =
        (TensorFlowRunJobParameters) jobRunParameters;

    assertFalse(jobRunParameters.isDistributeKeytab());
    assertFalse(jobRunParameters.isWaitJobFinish());
    assertFalse(tensorFlowParams.isTensorboardEnabled());
  }

  @Test
  public void testWrongIndentation() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/wrong-indentation.yaml");

    exception.expect(YamlParseException.class);
    exception.expectMessage("Failed to parse YAML config, details:");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testWrongFilename() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    exception.expect(YamlParseException.class);
    runJobCli.run(
        new String[]{"-f", "not-existing", "--verbose"});
  }

  @Test
  public void testEmptyFile() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createEmptyTempFile();

    exception.expect(YamlParseException.class);
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testNotExistingFile() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    exception.expect(YamlParseException.class);
    exception.expectMessage("file does not exist");
    runJobCli.run(
        new String[]{"-f", "blabla", "--verbose"});
  }

  @Test
  public void testWrongPropertyName() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/wrong-property-name.yaml");

    exception.expect(YamlParseException.class);
    exception.expectMessage("Failed to parse YAML config, details:");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testMissingConfigsSection() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/missing-configs.yaml");

    exception.expect(YamlParseException.class);
    exception.expectMessage("config section should be defined, " +
        "but it cannot be found");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testMissingSectionsShouldParsed() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/some-sections-missing.yaml");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testAbsentFramework() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/missing-framework.yaml");

    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testEmptyFramework() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/empty-framework.yaml");

    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testInvalidFramework() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/invalid-framework.yaml");

    exception.expect(YamlParseException.class);
    exception.expectMessage("framework should is defined, " +
        "but it has an invalid value");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }
}
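The assertions above pin down the error messages ("Failed to parse YAML config, details:", "file does not exist", the missing config section) without showing the production code. A sketch of the wrapping they imply, assuming a `YamlParseException(String)` constructor and SnakeYAML's `org.yaml.snakeyaml.error.YAMLException`:

```java
// Hypothetical parse helper mirroring the messages the tests assert.
static YamlConfigFile parseYamlConfig(File configFile)
    throws YamlParseException, IOException {
  if (configFile == null || !configFile.exists()) {
    throw new YamlParseException("YAML file does not exist: " + configFile);
  }
  try (InputStream in = new FileInputStream(configFile)) {
    Constructor constructor = new Constructor(YamlConfigFile.class);
    constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
    // loadAs returns null for an empty file; callers must treat that as
    // a parse failure too, which testEmptyFile above relies on.
    return new Yaml(constructor).loadAs(in, YamlConfigFile.class);
  } catch (YAMLException e) {
    throw new YamlParseException(
        "Failed to parse YAML config, details: " + e.getMessage());
  }
}
```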
@@ -0,0 +1,192 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.runjob;

import com.google.common.collect.Lists;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

/**
 * This class contains some test methods to test common CLI parsing
 * functionality (including TF / PyTorch) of the run job Submarine command.
 */
@RunWith(Parameterized.class)
public class TestRunJobCliParsingParameterized {

  private final Framework framework;

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @Rule
  public ExpectedException expectedException = ExpectedException.none();

  @Parameterized.Parameters
  public static Collection<Object[]> data() {
    Collection<Object[]> params = new ArrayList<>();
    params.add(new Object[]{Framework.TENSORFLOW });
    params.add(new Object[]{Framework.PYTORCH });
    return params;
  }

  public TestRunJobCliParsingParameterized(Framework framework) {
    this.framework = framework;
  }

  private String getFrameworkName() {
    return framework.getValue();
  }

  @Test
  public void testPrintHelp() {
    MockClientContext mockClientContext = new MockClientContext();
    JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
    JobMonitor mockJobMonitor = mock(JobMonitor.class);
    RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
        mockJobMonitor);
    runJobCli.printUsages();
  }

  @Test
  public void testNoInputPathOptionSpecified() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\"" +
        " is absent";
    String actualMessage = "";
    try {
      runJobCli.run(
          new String[]{"--framework", getFrameworkName(),
              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
              "--checkpoint_path", "hdfs://output",
              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
              "--worker_resources", "memory=4g,vcores=2", "--verbose",
              "--wait_job_finish"});
    } catch (ParseException e) {
      actualMessage = e.getMessage();
      e.printStackTrace();
    }
    assertEquals(expectedErrorMessage, actualMessage);
  }

  @Test
  public void testJobWithoutName() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    String expectedErrorMessage =
        "--" + CliConstants.NAME + " is absent";
    String actualMessage = "";
    try {
      runJobCli.run(
          new String[]{"--framework", getFrameworkName(),
              "--docker_image", "tf-docker:1.1.0",
              "--num_workers", "0", "--verbose"});
    } catch (ParseException e) {
      actualMessage = e.getMessage();
      e.printStackTrace();
    }
    assertEquals(expectedErrorMessage, actualMessage);
  }

  @Test
  public void testLaunchCommandPatternReplace() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    List<String> parameters = Lists.newArrayList("--framework",
        getFrameworkName(),
        "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
        "--input_path", "hdfs://input", "--checkpoint_path",
        "hdfs://output",
        "--num_workers", "3",
        "--worker_launch_cmd", "python run-job.py --input=%input_path% " +
            "--model_dir=%checkpoint_path% " +
            "--export_dir=%saved_model_path%/savedmodel",
        "--worker_resources", "memory=2048,vcores=2");

    if (framework == Framework.TENSORFLOW) {
      parameters.addAll(Lists.newArrayList(
          "--ps_resources", "memory=4096,vcores=4",
          "--ps_launch_cmd", "python run-ps.py --input=%input_path% " +
              "--model_dir=%checkpoint_path%/model",
          "--verbose"));
    }

    runJobCli.run(parameters.toArray(new String[0]));

    RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);

    if (framework == Framework.TENSORFLOW) {
      TensorFlowRunJobParameters tensorFlowParams =
          (TensorFlowRunJobParameters) runJobParameters;
      assertEquals(
          "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
              + "--export_dir=hdfs://output/savedmodel",
          tensorFlowParams.getWorkerLaunchCmd());
      assertEquals(
          "python run-ps.py --input=hdfs://input " +
              "--model_dir=hdfs://output/model",
          tensorFlowParams.getPSLaunchCmd());
    } else if (framework == Framework.PYTORCH) {
      PyTorchRunJobParameters pyTorchParameters =
          (PyTorchRunJobParameters) runJobParameters;
      assertEquals(
          "python run-job.py --input=hdfs://input --model_dir=hdfs://output "
              + "--export_dir=hdfs://output/savedmodel",
          pyTorchParameters.getWorkerLaunchCmd());
    }
  }

  private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
    RunJobParameters runJobParameters = runJobCli.getRunJobParameters();

    if (framework == Framework.TENSORFLOW) {
      assertTrue(RunJobParameters.class + " must be an instance of " +
          TensorFlowRunJobParameters.class,
          runJobParameters instanceof TensorFlowRunJobParameters);
    } else if (framework == Framework.PYTORCH) {
      assertTrue(RunJobParameters.class + " must be an instance of " +
          PyTorchRunJobParameters.class,
          runJobParameters instanceof PyTorchRunJobParameters);
    }
    return runJobParameters;
  }

}
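The parameterized tests above rely on a `Framework` enum with a `getValue()` accessor, which is not shown in this diff. A minimal sketch of the shape the tests and docs imply, mirroring the `Runtime` enum earlier in this patch (the exact definition in the patch may differ):

```java
/** Sketch of the Framework enum implied by the tests; values per the docs. */
public enum Framework {
  TENSORFLOW("tensorflow"), PYTORCH("pytorch");

  private final String value;

  Framework(String value) {
    this.value = value;
  }

  public String getValue() {
    return value;
  }

  /** Case-insensitive lookup, analogous to Runtime.parseByValue (assumed). */
  public static Framework parseByValue(String value) {
    for (Framework fw : Framework.values()) {
      if (fw.value.equalsIgnoreCase(value)) {
        return fw;
      }
    }
    return null;
  }
}
```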
@@ -0,0 +1,209 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;

import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/**
 * Test class that verifies the correctness of PyTorch
 * CLI configuration parsing.
 */
public class TestRunJobCliParsingPyTorch {

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @Rule
  public ExpectedException expectedException = ExpectedException.none();

  @Test
  public void testBasicRunJobForSingleNodeTraining() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path",
            "hdfs://output",
            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
            "--worker_resources", "memory=4g,vcores=2", "--verbose",
            "--wait_job_finish" });

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    assertTrue(RunJobParameters.class +
        " must be an instance of " +
        PyTorchRunJobParameters.class,
        jobRunParameters instanceof PyTorchRunJobParameters);
    PyTorchRunJobParameters pyTorchParams =
        (PyTorchRunJobParameters) jobRunParameters;

    assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
    assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
    assertEquals(pyTorchParams.getNumWorkers(), 1);
    assertEquals(pyTorchParams.getWorkerLaunchCmd(),
        "python run-job.py");
    assertEquals(Resources.createResource(4096, 2),
        pyTorchParams.getWorkerResource());
    assertTrue(SubmarineLogs.isVerbose());
    assertTrue(jobRunParameters.isWaitJobFinish());
  }

  @Test
  public void testNumPSCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path","hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--num_ps", "2" });
  }

  @Test
  public void testPSResourcesCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--ps_resources", "memory=2048M,vcores=2" });
  }

  @Test
  public void testPSDockerImageCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--ps_docker_image", "psDockerImage" });
  }

  @Test
  public void testPSLaunchCommandCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--ps_launch_cmd", "psLaunchCommand" });
  }

  @Test
  public void testTensorboardCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--tensorboard" });
  }

  @Test
  public void testTensorboardResourcesCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--tensorboard_resources", "memory=2048M,vcores=2" });
  }

  @Test
  public void testTensorboardDockerImageCannotBeDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    expectedException.expect(ParseException.class);
    expectedException.expectMessage("cannot be defined for PyTorch jobs");
    runJobCli.run(
        new String[] {"--framework", "pytorch",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3",
            "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--tensorboard_docker_image", "TBDockerImage" });
  }

}
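Each of the rejection tests above expects a `ParseException` whose message contains "cannot be defined for PyTorch jobs". A sketch of the validation they imply (the option list is taken from the tests; the method name and its location in the parser are assumptions):

```java
/** Hypothetical validator: TensorFlow-only options are illegal for PyTorch. */
private static void validatePyTorchOptions(CommandLine cli)
    throws ParseException {
  String[] tfOnlyOptions = {"num_ps", "ps_resources", "ps_docker_image",
      "ps_launch_cmd", "tensorboard", "tensorboard_resources",
      "tensorboard_docker_image"};
  for (String option : tfOnlyOptions) {
    if (cli.hasOption(option)) {
      throw new ParseException("Option '--" + option
          + "' cannot be defined for PyTorch jobs");
    }
  }
}
```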
@@ -0,0 +1,225 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *   http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;

import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

/**
 * Test class that verifies the correctness of PyTorch
 * YAML configuration parsing.
 */
public class TestRunJobCliParsingPyTorchYaml {
  private static final String OVERRIDDEN_PREFIX = "overridden_";
  private static final String DIR_NAME = "runjob-pytorch-yaml";
  private File yamlConfig;

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @After
  public void after() {
    YamlConfigTestUtils.deleteFile(yamlConfig);
  }

  @BeforeClass
  public static void configureResourceTypes() {
    List<ResourceTypeInfo> resTypes = new ArrayList<>(
        ResourceUtils.getResourcesTypeInfo());
    resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
    ResourceUtils.reinitializeResources(resTypes);
  }

  @Rule
  public ExpectedException exception = ExpectedException.none();

  private void verifyBasicConfigValues(RunJobParameters jobRunParameters) {
    verifyBasicConfigValues(jobRunParameters,
        ImmutableList.of("env1=env1Value", "env2=env2Value"));
  }

  private void verifyBasicConfigValues(RunJobParameters jobRunParameters,
      List<String> expectedEnvs) {
    assertEquals("testInputPath", jobRunParameters.getInputPath());
    assertEquals("testCheckpointPath", jobRunParameters.getCheckpointPath());
    assertEquals("testDockerImage", jobRunParameters.getDockerImageName());

    assertNotNull(jobRunParameters.getLocalizations());
    assertEquals(2, jobRunParameters.getLocalizations().size());

    assertNotNull(jobRunParameters.getQuicklinks());
    assertEquals(2, jobRunParameters.getQuicklinks().size());

    assertTrue(SubmarineLogs.isVerbose());
    assertTrue(jobRunParameters.isWaitJobFinish());

    for (String env : expectedEnvs) {
      assertTrue(String.format(
          "%s should be in env list of jobRunParameters!", env),
          jobRunParameters.getEnvars().contains(env));
    }
  }

  private void verifyWorkerValues(RunJobParameters jobRunParameters,
      String prefix) {
    assertTrue(RunJobParameters.class + " must be an instance of " +
        PyTorchRunJobParameters.class,
        jobRunParameters instanceof PyTorchRunJobParameters);
    PyTorchRunJobParameters pyTorchParams =
        (PyTorchRunJobParameters) jobRunParameters;

    assertEquals(3, pyTorchParams.getNumWorkers());
    assertEquals(prefix + "testLaunchCmdWorker",
        pyTorchParams.getWorkerLaunchCmd());
    assertEquals(prefix + "testDockerImageWorker",
        pyTorchParams.getWorkerDockerImage());
    assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
        ImmutableMap.<String, String> builder()
            .put(ResourceInformation.GPU_URI, "2").build()),
        pyTorchParams.getWorkerResource());
  }

  private void verifySecurityValues(RunJobParameters jobRunParameters) {
    assertEquals("keytabPath", jobRunParameters.getKeytab());
    assertEquals("testPrincipal", jobRunParameters.getPrincipal());
    assertTrue(jobRunParameters.isDistributeKeytab());
  }

  @Test
  public void testValidYamlParsing() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/valid-config.yaml");
    runJobCli.run(
        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    verifyBasicConfigValues(jobRunParameters);
    verifyWorkerValues(jobRunParameters, "");
    verifySecurityValues(jobRunParameters);
  }

  @Test
  public void testRoleOverrides() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    Assert.assertFalse(SubmarineLogs.isVerbose());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/valid-config-with-overrides.yaml");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    verifyBasicConfigValues(jobRunParameters);
    verifyWorkerValues(jobRunParameters, OVERRIDDEN_PREFIX);
    verifySecurityValues(jobRunParameters);
  }

  @Test
  public void testMissingPrincipalUnderSecuritySection() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/security-principal-is-missing.yaml");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    verifyBasicConfigValues(jobRunParameters);
    verifyWorkerValues(jobRunParameters, "");

    // Verify security values
    assertEquals("keytabPath", jobRunParameters.getKeytab());
    assertNull("Principal should be null!", jobRunParameters.getPrincipal());
    assertTrue(jobRunParameters.isDistributeKeytab());
  }

  @Test
  public void testMissingEnvs() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/envs-are-missing.yaml");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    verifyBasicConfigValues(jobRunParameters, ImmutableList.of());
    verifyWorkerValues(jobRunParameters, "");
    verifySecurityValues(jobRunParameters);
  }

  @Test
  public void testInvalidConfigPsSectionIsDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    exception.expect(YamlParseException.class);
    exception.expectMessage("PS section should not be defined " +
        "when PyTorch is the selected framework");
    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/invalid-config-ps-section.yaml");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }

  @Test
  public void testInvalidConfigTensorboardSectionIsDefined() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    exception.expect(YamlParseException.class);
    exception.expectMessage("TensorBoard section should not be defined " +
        "when PyTorch is the selected framework");
    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
        DIR_NAME + "/invalid-config-tensorboard-section.yaml");
    runJobCli.run(
        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
  }
}
@@ -0,0 +1,170 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *   http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;

import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

/**
 * Test class that verifies the correctness of TensorFlow
 * CLI configuration parsing.
 */
public class TestRunJobCliParsingTensorFlow {

  @Before
  public void before() {
    SubmarineLogs.verboseOff();
  }

  @Rule
  public ExpectedException expectedException = ExpectedException.none();

  @Test
  public void testNoInputPathOptionSpecified() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
        "\" is absent";
    String actualMessage = "";
    try {
      runJobCli.run(
          new String[]{"--framework", "tensorflow",
              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
              "--checkpoint_path", "hdfs://output",
              "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
              "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
              "true", "--verbose", "--wait_job_finish"});
    } catch (ParseException e) {
      actualMessage = e.getMessage();
      e.printStackTrace();
    }
    assertEquals(expectedErrorMessage, actualMessage);
  }

  @Test
  public void testBasicRunJobForDistributedTraining() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());

    assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[] { "--framework", "tensorflow",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input",
            "--checkpoint_path", "hdfs://output",
            "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
            "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
            "--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
            "--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
            "--principal", "user/_HOST@domain.com", "--distribute_keytab",
            "--verbose" });

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    assertTrue(RunJobParameters.class +
        " must be an instance of " +
        TensorFlowRunJobParameters.class,
        jobRunParameters instanceof TensorFlowRunJobParameters);
    TensorFlowRunJobParameters tensorFlowParams =
        (TensorFlowRunJobParameters) jobRunParameters;

    assertEquals("hdfs://input", jobRunParameters.getInputPath());
    assertEquals("hdfs://output", jobRunParameters.getCheckpointPath());
    assertEquals(2, tensorFlowParams.getNumPS());
    assertEquals("python run-ps.py", tensorFlowParams.getPSLaunchCmd());
    assertEquals(Resources.createResource(4096, 4),
        tensorFlowParams.getPsResource());
    assertEquals("python run-job.py",
        tensorFlowParams.getWorkerLaunchCmd());
    assertEquals(Resources.createResource(2048, 2),
        tensorFlowParams.getWorkerResource());
    assertEquals("tf-docker:1.1.0",
        jobRunParameters.getDockerImageName());
    assertEquals("/keytab/path",
        jobRunParameters.getKeytab());
    assertEquals("user/_HOST@domain.com",
        jobRunParameters.getPrincipal());
    assertTrue(jobRunParameters.isDistributeKeytab());
    assertTrue(SubmarineLogs.isVerbose());
  }

  @Test
  public void testBasicRunJobForSingleNodeTraining() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    assertFalse(SubmarineLogs.isVerbose());

    runJobCli.run(
        new String[] { "--framework", "tensorflow",
            "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
            "--input_path", "hdfs://input", "--checkpoint_path",
            "hdfs://output",
            "--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
            "--worker_resources", "memory=4g,vcores=2", "--tensorboard",
            "true", "--verbose", "--wait_job_finish" });

    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
    assertTrue(RunJobParameters.class +
        " must be an instance of " +
        TensorFlowRunJobParameters.class,
        jobRunParameters instanceof TensorFlowRunJobParameters);
    TensorFlowRunJobParameters tensorFlowParams =
        (TensorFlowRunJobParameters) jobRunParameters;

    assertEquals("hdfs://input", jobRunParameters.getInputPath());
    assertEquals("hdfs://output", jobRunParameters.getCheckpointPath());
    assertEquals(1, tensorFlowParams.getNumWorkers());
    assertEquals("python run-job.py",
        tensorFlowParams.getWorkerLaunchCmd());
    assertEquals(Resources.createResource(4096, 2),
        tensorFlowParams.getWorkerResource());
    assertTrue(SubmarineLogs.isVerbose());
    assertTrue(jobRunParameters.isWaitJobFinish());
  }

  /**
   * When only TensorBoard is run, input_path is not needed.
   */
  @Test
  public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
    boolean success = true;
    try {
      runJobCli.run(
          new String[]{"--framework", "tensorflow",
              "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
              "--num_workers", "0", "--tensorboard", "--verbose",
              "--tensorboard_resources", "memory=2G,vcores=2",
              "--tensorboard_docker_image", "tb_docker_image:001"});
    } catch (ParseException e) {
      success = false;
    }
    assertTrue(success);
  }
}
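Note: the try/catch-and-compare pattern in testNoInputPathOptionSpecified could be condensed into a small helper. A sketch, hypothetical and not part of the patch; it additionally needs a static import of org.junit.Assert.fail:

// Hypothetical helper: runs the CLI and returns the ParseException
// message, failing the test if no ParseException is thrown.
private static String runAndCaptureParseError(RunJobCli cli, String... args)
    throws Exception {
  try {
    cli.run(args);
    fail("Expected a ParseException to be thrown");
  } catch (ParseException e) {
    return e.getMessage();
  }
  return null; // not reached
}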
@@ -15,16 +15,17 @@
  */
 
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
-import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
+import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.junit.After;
@@ -39,19 +40,18 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.List;
 
-import static org.apache.hadoop.yarn.submarine.client.cli.TestRunJobCliParsing.getMockClientContext;
+import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 /**
- * Test class that verifies the correctness of YAML configuration parsing.
+ * Test class that verifies the correctness of TF YAML configuration parsing.
  */
-public class TestRunJobCliParsingYaml {
+public class TestRunJobCliParsingTensorFlowYaml {
   private static final String OVERRIDDEN_PREFIX = "overridden_";
-  private static final String DIR_NAME = "runjobcliparsing";
+  private static final String DIR_NAME = "runjob-tensorflow-yaml";
   private File yamlConfig;
 
   @Before
@@ -104,27 +104,39 @@ public class TestRunJobCliParsingYaml {
 
   private void verifyPsValues(RunJobParameters jobRunParameters,
       String prefix) {
-    assertEquals(4, jobRunParameters.getNumPS());
-    assertEquals(prefix + "testLaunchCmdPs", jobRunParameters.getPSLaunchCmd());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+        TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(4, tensorFlowParams.getNumPS());
+    assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
     assertEquals(prefix + "testDockerImagePs",
-        jobRunParameters.getPsDockerImage());
+        tensorFlowParams.getPsDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(20500L, 34,
         ImmutableMap.<String, String> builder()
            .put(ResourceInformation.GPU_URI, "4").build()),
-        jobRunParameters.getPsResource());
+        tensorFlowParams.getPsResource());
   }
 
   private void verifyWorkerValues(RunJobParameters jobRunParameters,
       String prefix) {
-    assertEquals(3, jobRunParameters.getNumWorkers());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+        TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertEquals(3, tensorFlowParams.getNumWorkers());
     assertEquals(prefix + "testLaunchCmdWorker",
-        jobRunParameters.getWorkerLaunchCmd());
+        tensorFlowParams.getWorkerLaunchCmd());
     assertEquals(prefix + "testDockerImageWorker",
-        jobRunParameters.getWorkerDockerImage());
+        tensorFlowParams.getWorkerDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
         ImmutableMap.<String, String> builder()
            .put(ResourceInformation.GPU_URI, "2").build()),
-        jobRunParameters.getWorkerResource());
+        tensorFlowParams.getWorkerResource());
   }
 
   private void verifySecurityValues(RunJobParameters jobRunParameters) {
@@ -134,13 +146,19 @@ public class TestRunJobCliParsingYaml {
   }
 
   private void verifyTensorboardValues(RunJobParameters jobRunParameters) {
-    assertTrue(jobRunParameters.isTensorboardEnabled());
+    assertTrue(RunJobParameters.class + " must be an instance of " +
+        TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertTrue(tensorFlowParams.isTensorboardEnabled());
     assertEquals("tensorboardDockerImage",
-        jobRunParameters.getTensorboardDockerImage());
+        tensorFlowParams.getTensorboardDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
         ImmutableMap.<String, String> builder()
            .put(ResourceInformation.GPU_URI, "3").build()),
-        jobRunParameters.getTensorboardResource());
+        tensorFlowParams.getTensorboardResource());
   }
 
   @Test
@@ -161,44 +179,6 @@ public class TestRunJobCliParsingYaml {
     verifyTensorboardValues(jobRunParameters);
   }
 
-  @Test
-  public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/valid-config.yaml");
-    String[] args = new String[] {"--name", "my-job",
-        "--docker_image", "tf-docker:1.1.0",
-        "-f", yamlConfig.getAbsolutePath() };
-
-    exception.expect(YarnException.class);
-    exception.expectMessage("defined both with YAML config and with " +
-        "CLI argument");
-
-    runJobCli.run(args);
-  }
-
-  @Test
-  public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
-      throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/valid-config.yaml");
-    String[] args = new String[] {"--name", "my-job",
-        "--quicklink", "AAA=http://master-0:8321",
-        "--quicklink", "BBB=http://worker-0:1234",
-        "-f", yamlConfig.getAbsolutePath()};
-
-    exception.expect(YarnException.class);
-    exception.expectMessage("defined both with YAML config and with " +
-        "CLI argument");
-
-    runJobCli.run(args);
-  }
-
   @Test
   public void testRoleOverrides() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@@ -217,104 +197,6 @@ public class TestRunJobCliParsingYaml {
     verifyTensorboardValues(jobRunParameters);
   }
 
-  @Test
-  public void testFalseValuesForBooleanFields() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/test-false-values.yaml");
-    runJobCli.run(
-        new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-    RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
-
-    assertFalse(jobRunParameters.isDistributeKeytab());
-    assertFalse(jobRunParameters.isWaitJobFinish());
-    assertFalse(jobRunParameters.isTensorboardEnabled());
-  }
-
-  @Test
-  public void testWrongIndentation() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/wrong-indentation.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("Failed to parse YAML config, details:");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testWrongFilename() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-    Assert.assertFalse(SubmarineLogs.isVerbose());
-
-    exception.expect(YamlParseException.class);
-    runJobCli.run(
-        new String[]{"-f", "not-existing", "--verbose"});
-  }
-
-  @Test
-  public void testEmptyFile() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
-
-    exception.expect(YamlParseException.class);
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testNotExistingFile() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("file does not exist");
-    runJobCli.run(
-        new String[]{"-f", "blabla", "--verbose"});
-  }
-
-  @Test
-  public void testWrongPropertyName() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/wrong-property-name.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("Failed to parse YAML config, details:");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testMissingConfigsSection() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/missing-configs.yaml");
-
-    exception.expect(YamlParseException.class);
-    exception.expectMessage("config section should be defined, " +
-        "but it cannot be found");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
-  @Test
-  public void testMissingSectionsShouldParsed() throws Exception {
-    RunJobCli runJobCli = new RunJobCli(getMockClientContext());
-
-    yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
-        DIR_NAME + "/some-sections-missing.yaml");
-    runJobCli.run(
-        new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
-  }
-
   @Test
   public void testMissingPrincipalUnderSecuritySection() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@@ -346,18 +228,22 @@ public class TestRunJobCliParsingYaml {
         new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
 
     RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
 
     verifyBasicConfigValues(jobRunParameters);
     verifyPsValues(jobRunParameters, "");
     verifyWorkerValues(jobRunParameters, "");
     verifySecurityValues(jobRunParameters);
 
-    assertTrue(jobRunParameters.isTensorboardEnabled());
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
+    assertTrue(tensorFlowParams.isTensorboardEnabled());
     assertNull("tensorboardDockerImage should be null!",
-        jobRunParameters.getTensorboardDockerImage());
+        tensorFlowParams.getTensorboardDockerImage());
     assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
         ImmutableMap.<String, String> builder()
            .put(ResourceInformation.GPU_URI, "3").build()),
-        jobRunParameters.getTensorboardResource());
+        tensorFlowParams.getTensorboardResource());
   }
 
   @Test
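Note: the assert-instance-then-cast sequence introduced above recurs in verifyPsValues, verifyWorkerValues, and verifyTensorboardValues; a generic helper could centralize it. A sketch, hypothetical and not part of the patch:

// Hypothetical helper: asserts the runtime type and performs the cast in
// one step, so each verify method needs only a single call.
private static <T extends RunJobParameters> T expectInstanceOf(
    RunJobParameters params, Class<T> type) {
  assertTrue(params + " must be an instance of " + type,
      type.isInstance(params));
  return type.cast(params);
}

// Example use inside a verify method:
// TensorFlowRunJobParameters tensorFlowParams =
//     expectInstanceOf(jobRunParameters, TensorFlowRunJobParameters.class);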
@@ -15,7 +15,7 @@
  */
 
 
-package org.apache.hadoop.yarn.submarine.client.cli;
+package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
 
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
 import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
@@ -42,14 +42,9 @@ import static org.junit.Assert.assertTrue;
  * Please note that this class just tests YAML parsing,
  * but only in an isolated fashion.
  */
-public class TestRunJobCliParsingYamlStandalone {
+public class TestRunJobCliParsingTensorFlowYamlStandalone {
   private static final String OVERRIDDEN_PREFIX = "overridden_";
-  private static final String DIR_NAME = "runjobcliparsing";
+  private static final String DIR_NAME = "runjob-tensorflow-yaml";
 
-  @Before
-  public void before() {
-    SubmarineLogs.verboseOff();
-  }
-
   private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) {
     assertNotNull("Spec file should not be null!", yamlConfigFile);
@@ -169,6 +164,11 @@ public class TestRunJobCliParsingYamlStandalone {
     assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources());
   }
 
+  @Before
+  public void before() {
+    SubmarineLogs.verboseOff();
+  }
+
   @Test
   public void testLaunchCommandYaml() {
     YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME +
@@ -201,5 +201,4 @@ public class TestRunJobCliParsingYamlStandalone {
     assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker");
     assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps");
   }
-
 }
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework:

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker
  ps:
    resources: memory=20500M,vcores=34,gpu=4
    replicas: 4
    launch_cmd: testLaunchCmdPs
    docker_image: testDockerImagePs

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true

tensorBoard:
  resources: memory=21000M,vcores=37,gpu=3
  docker_image: tensorboardDockerImage
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: bla

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker
  ps:
    resources: memory=20500M,vcores=34,gpu=4
    replicas: 4
    launch_cmd: testLaunchCmdPs
    docker_image: testDockerImagePs

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true

tensorBoard:
  resources: memory=21000M,vcores=37,gpu=3
  docker_image: tensorboardDockerImage
@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath
@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: pytorch

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: pytorch

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker
  ps:
    docker_image: testPSDockerImage

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: pytorch

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true

tensorBoard:
  docker_image: tensorboardDockerImage
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: pytorch

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker

security:
  keytab: keytabPath
  distribute_keytab: true
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: pytorch

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: overridden_testLaunchCmdWorker
    docker_image: overridden_testDockerImageWorker
    envs:
      env1: 'overridden_env1Worker'
      env2: 'overridden_env2Worker'
    localizations:
      - hdfs://remote-file1:/overridden_local-filename1Worker:rw
      - nfs://remote-file2:/overridden_local-filename2Worker:rw
    mounts:
      - /etc/passwd:/overridden_Worker
      - /etc/hosts:/overridden_Worker

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: pytorch

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true
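Note: for reference, the job described by this valid-config.yaml can also be expressed with plain CLI flags. A sketch, drawn from the flag spelling used by the PyTorch CLI tests earlier in this patch; whether gpu=... is accepted in a command-line resource string, as it is in the YAML, is an assumption here, and env and quicklink options are omitted:

// Sketch: a CLI-flag equivalent of the YAML fixture above.
runJobCli.run(new String[] {"--framework", "pytorch",
    "--name", "testJobName",
    "--docker_image", "testDockerImage",
    "--input_path", "testInputPath",
    "--checkpoint_path", "testCheckpointPath",
    "--num_workers", "3",
    "--worker_launch_cmd", "testLaunchCmdWorker",
    "--worker_resources", "memory=20480M,vcores=32,gpu=2",  // gpu= assumed
    "--keytab", "keytabPath",
    "--principal", "testPrincipal",
    "--distribute_keytab",
    "--wait_job_finish", "--verbose"});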
@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath
@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath
@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath
@@ -17,6 +17,7 @@
 spec:
   name: testJobName
   job_type: testJobType
+  framework: tensorflow
 
 configs:
   input_path: testInputPath
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

spec:
  name: testJobName
  job_type: testJobType
  framework: tensorflow

configs:
  input_path: testInputPath
  checkpoint_path: testCheckpointPath
  saved_model_path: testSavedModelPath
  docker_image: testDockerImage
  wait_job_finish: true
  envs:
    env1: 'env1Value'
    env2: 'env2Value'
  localizations:
    - hdfs://remote-file1:/local-filename1:rw
    - nfs://remote-file2:/local-filename2:rw
  mounts:
    - /etc/passwd:/etc/passwd:rw
    - /etc/hosts:/etc/hosts:rw
  quicklinks:
    - Notebook_UI=https://master-0:7070
    - Notebook_UI2=https://master-0:7071

scheduling:
  queue: queue1

roles:
  worker:
    resources: memory=20480M,vcores=32,gpu=2
    replicas: 3
    launch_cmd: testLaunchCmdWorker
    docker_image: testDockerImageWorker
  ps:
    resources: memory=20500M,vcores=34,gpu=4
    replicas: 4
    launch_cmd: testLaunchCmdPs
    docker_image: testDockerImagePs

security:
  keytab: keytabPath
  principal: testPrincipal
  distribute_keytab: true

tensorBoard:
  resources: memory=21000M,vcores=37,gpu=3
  docker_image: tensorboardDockerImage
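For reference, the resources strings in these test configs (memory=20480M,vcores=32,gpu=2 and friends) are simple key=value lists. A minimal, hypothetical sketch of parsing that format — the real parser lives in Submarine's CLI code and is not part of this diff:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical helper, for illustration only: splits a string such as
    // "memory=20480M,vcores=32,gpu=2" into a key/value map.
    public final class ResourceStringParser {
      private ResourceStringParser() {
      }

      public static Map<String, String> parse(String resources) {
        Map<String, String> result = new HashMap<>();
        for (String pair : resources.split(",")) {
          String[] kv = pair.split("=", 2);
          result.put(kv[0].trim(), kv[1].trim());
        }
        return result;
      }
    }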
@@ -22,7 +22,9 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;

 import java.io.File;
@@ -45,14 +47,24 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
   }

   @Override
-  public ApplicationId submitJob(RunJobParameters parameters)
-      throws IOException, YarnException {
+  public ApplicationId submitJob(ParametersHolder parameters)
+      throws IOException {
+    if (parameters.getFramework() == Framework.PYTORCH) {
+      // We need to throw an exception, as ParametersHolder's parameters field
+      // could not be cast to TensorFlowRunJobParameters.
+      throw new UnsupportedOperationException(
+          "Support \"--framework\" option for PyTorch in Tony is coming. " +
+          "Please check the documentation about how to submit a " +
+          "PyTorch job with TonY runtime.");
+    }
+
     LOG.info("Starting Tony runtime..");

     File tonyFinalConfPath = File.createTempFile("temp",
         Constants.TONY_FINAL_XML);
     // Write user's overridden conf to an xml to be localized.
-    Configuration tonyConf = TonyUtils.tonyConfFromClientContext(parameters);
+    Configuration tonyConf = TonyUtils.tonyConfFromClientContext(
+        (TensorFlowRunJobParameters) parameters.getParameters());
     try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
       tonyConf.writeXml(os);
     } catch (IOException e) {
@@ -68,7 +80,7 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
       LOG.error("Failed to init TonyClient: ", e);
     }
     Thread clientThread = new Thread(tonyClient::start);
-    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+    java.lang.Runtime.getRuntime().addShutdownHook(new Thread(() -> {
       try {
         tonyClient.forceKillApplication();
       } catch (YarnException | IOException e) {
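The signature change above is the heart of this refactor: submitJob now receives a ParametersHolder that carries a framework tag plus an untyped parameters object, and each runtime downcasts only after checking the tag. A self-contained sketch of the pattern, where the class bodies are stand-ins and not the real Submarine classes:

    // Stand-in types, for illustration only.
    enum Framework { TENSORFLOW, PYTORCH }

    class RunJobParameters { }

    class TensorFlowRunJobParameters extends RunJobParameters { }

    class ParametersHolder {
      private final Framework framework;
      private final RunJobParameters parameters;

      ParametersHolder(Framework framework, RunJobParameters parameters) {
        this.framework = framework;
        this.parameters = parameters;
      }

      Framework getFramework() { return framework; }

      RunJobParameters getParameters() { return parameters; }
    }

    class SubmitterSketch {
      void submit(ParametersHolder holder) {
        if (holder.getFramework() == Framework.PYTORCH) {
          // Mirrors the guard above: the TonY path can only handle
          // TensorFlow parameters for now.
          throw new UnsupportedOperationException("PyTorch is not supported");
        }
        TensorFlowRunJobParameters tfParams =
            (TensorFlowRunJobParameters) holder.getParameters();
        // ... build the TonY configuration from tfParams ...
      }
    }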
@@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ResourceInformation;
 import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;

 import java.util.ArrayList;
 import java.util.Arrays;
@@ -35,7 +35,7 @@ public final class TonyUtils {
   private static final Log LOG = LogFactory.getLog(TonyUtils.class);

   public static Configuration tonyConfFromClientContext(
-      RunJobParameters parameters) {
+      TensorFlowRunJobParameters parameters) {
     Configuration tonyConf = new Configuration();
     tonyConf.setInt(
         TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME),
@@ -147,6 +147,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \

 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+--framework tensorflow \
 --num_workers 2 \
 --worker_resources memory=3G,vcores=2 \
 --num_ps 2 \
@@ -183,6 +184,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \

 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+--framework tensorflow \
 --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
 --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
 --worker_resources memory=3G,vcores=2 \
@@ -245,6 +247,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \

 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+--framework tensorflow \
 --num_workers 2 \
 --worker_resources memory=3G,vcores=2 \
 --num_ps 2 \
@@ -281,6 +284,7 @@ CLASSPATH=$(hadoop classpath --glob): \
 /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \

 java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
+--framework tensorflow \
 --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
 --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
 --worker_resources memory=3G,vcores=2 \
@@ -16,8 +16,10 @@ import com.linkedin.tony.TonyConfigurationKeys;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.MockClientContext;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
@@ -31,6 +33,7 @@ import org.junit.Test;

 import java.io.IOException;

+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -59,7 +62,8 @@ public class TestTonyUtils {
       throws IOException, YarnException {
     MockClientContext mockClientContext = new MockClientContext();
     JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
-    when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
+    when(mockJobSubmitter.submitJob(
+        any(ParametersHolder.class))).thenReturn(
         ApplicationId.newInstance(1234L, 1));
     JobMonitor mockJobMonitor = mock(JobMonitor.class);
     SubmarineStorage storage = mock(SubmarineStorage.class);
@@ -82,20 +86,28 @@ public class TestTonyUtils {
   public void testTonyConfFromClientContext() throws Exception {
     RunJobCli runJobCli = new RunJobCli(getMockClientContext());
     runJobCli.run(
-        new String[] {"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
+        new String[] {"--framework", "tensorflow", "--name", "my-job",
+            "--docker_image", "tf-docker:1.1.0",
             "--input_path", "hdfs://input",
             "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
             "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
             "--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd",
             "python run-ps.py"});
     RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();

+    assertTrue(RunJobParameters.class + " must be an instance of " +
+            TensorFlowRunJobParameters.class,
+        jobRunParameters instanceof TensorFlowRunJobParameters);
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) jobRunParameters;
+
     Configuration tonyConf = TonyUtils
-        .tonyConfFromClientContext(jobRunParameters);
+        .tonyConfFromClientContext(tensorFlowParams);
     Assert.assertEquals(jobRunParameters.getDockerImageName(),
         tonyConf.get(TonyConfigurationKeys.getContainerDockerKey()));
     Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys
         .getInstancesKey("worker")));
-    Assert.assertEquals(jobRunParameters.getWorkerLaunchCmd(),
+    Assert.assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
         tonyConf.get(TonyConfigurationKeys
             .getExecuteCommandKey("worker")));
     Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys
@@ -107,7 +119,7 @@ public class TestTonyUtils {
     Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
         .getResourceKey(Constants.PS_JOB_NAME,
             Constants.VCORES)));
-    Assert.assertEquals(jobRunParameters.getPSLaunchCmd(),
+    Assert.assertEquals(tensorFlowParams.getPSLaunchCmd(),
         tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
   }
 }
@@ -19,8 +19,10 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
@@ -28,7 +30,11 @@ import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchComma
 import java.io.IOException;
 import java.util.Objects;

+import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
+import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
+import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
+import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;

 /**
  * Abstract base class for Component classes.
@@ -40,7 +46,7 @@ import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.T
 public abstract class AbstractComponent {
   private final FileSystemOperations fsOperations;
   protected final RunJobParameters parameters;
-  protected final TaskType taskType;
+  protected final Role role;
   private final RemoteDirectoryManager remoteDirectoryManager;
   protected final Configuration yarnConfig;
   private final LaunchCommandFactory launchCommandFactory;
@@ -52,19 +58,55 @@ public abstract class AbstractComponent {

   public AbstractComponent(FileSystemOperations fsOperations,
       RemoteDirectoryManager remoteDirectoryManager,
-      RunJobParameters parameters, TaskType taskType,
+      RunJobParameters parameters, Role role,
       Configuration yarnConfig,
       LaunchCommandFactory launchCommandFactory) {
     this.fsOperations = fsOperations;
     this.remoteDirectoryManager = remoteDirectoryManager;
     this.parameters = parameters;
-    this.taskType = taskType;
+    this.role = role;
     this.launchCommandFactory = launchCommandFactory;
     this.yarnConfig = yarnConfig;
   }

   protected abstract Component createComponent() throws IOException;

+  protected Component createComponentInternal() throws IOException {
+    Objects.requireNonNull(this.parameters.getWorkerResource(),
+        "Worker resource must not be null!");
+    if (parameters.getNumWorkers() < 1) {
+      throw new IllegalArgumentException(
+          "Number of workers should be at least 1!");
+    }
+
+    Component component = new Component();
+    component.setName(role.getComponentName());
+
+    if (role.equals(TensorFlowRole.PRIMARY_WORKER) ||
+        role.equals(PyTorchRole.PRIMARY_WORKER)) {
+      component.setNumberOfContainers(1L);
+      component.getConfiguration().setProperty(
+          CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
+    } else {
+      component.setNumberOfContainers(
+          (long) parameters.getNumWorkers() - 1);
+    }
+
+    if (parameters.getWorkerDockerImage() != null) {
+      component.setArtifact(
+          getDockerArtifact(parameters.getWorkerDockerImage()));
+    }
+
+    component.setResource(convertYarnResourceToServiceResource(
+        parameters.getWorkerResource()));
+    component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
+
+    addCommonEnvironments(component, role);
+    generateLaunchCommand(component);
+
+    return component;
+  }
+
   /**
    * Generates a command launch script on local disk,
    * returns path to the script.
@@ -72,7 +114,7 @@ public abstract class AbstractComponent {
   protected void generateLaunchCommand(Component component)
       throws IOException {
     AbstractLaunchCommand launchCommand =
-        launchCommandFactory.createLaunchCommand(taskType, component);
+        launchCommandFactory.createLaunchCommand(role, component);
     this.localScriptFile = launchCommand.generateLaunchScript();

     String remoteLaunchCommand = uploadLaunchCommand(component);
@@ -86,7 +128,7 @@ public abstract class AbstractComponent {
     Path stagingDir =
         remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);

-    String destScriptFileName = getScriptFileName(taskType);
+    String destScriptFileName = getScriptFileName(role);
     fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
         localScriptFile, destScriptFileName, component);
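A concrete component now only has to implement createComponent(); plain worker components can delegate to the shared createComponentInternal() above. A minimal hypothetical subclass for illustration — the actual TensorFlowWorkerComponent and PyTorchWorkerComponent in this commit may add framework-specific steps:

    // Hypothetical subclass, illustrating the intended extension point.
    public class SimpleWorkerComponent extends AbstractComponent {
      public SimpleWorkerComponent(FileSystemOperations fsOperations,
          RemoteDirectoryManager remoteDirectoryManager,
          RunJobParameters parameters, Role role,
          Configuration yarnConfig,
          LaunchCommandFactory launchCommandFactory) {
        super(fsOperations, remoteDirectoryManager, parameters, role,
            yarnConfig, launchCommandFactory);
      }

      @Override
      protected Component createComponent() throws IOException {
        // Name, container count, docker image, resources, restart policy
        // and launch command are all wired up by the shared helper.
        return createComponentInternal();
      }
    }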
@@ -0,0 +1,167 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;

/**
 * Abstract base class that supports creating service specs for Native Service.
 */
public abstract class AbstractServiceSpec implements ServiceSpec {
  private static final Logger LOG =
      LoggerFactory.getLogger(AbstractServiceSpec.class);
  protected final RunJobParameters parameters;
  protected final FileSystemOperations fsOperations;
  private final Localizer localizer;
  protected final RemoteDirectoryManager remoteDirectoryManager;
  protected final Configuration yarnConfig;
  protected final LaunchCommandFactory launchCommandFactory;
  private final WorkerComponentFactory workerFactory;

  public AbstractServiceSpec(RunJobParameters parameters,
      ClientContext clientContext, FileSystemOperations fsOperations,
      LaunchCommandFactory launchCommandFactory,
      Localizer localizer) {
    this.parameters = parameters;
    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
    this.yarnConfig = clientContext.getYarnConfig();
    this.fsOperations = fsOperations;
    this.localizer = localizer;
    this.launchCommandFactory = launchCommandFactory;
    this.workerFactory = new WorkerComponentFactory(fsOperations,
        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
  }

  protected ServiceWrapper createServiceSpecWrapper() throws IOException {
    Service serviceSpec = new Service();
    serviceSpec.setName(parameters.getName());
    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));

    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
        .create(fsOperations, remoteDirectoryManager, parameters);
    if (kerberosPrincipal != null) {
      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
    }

    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
    localizer.handleLocalizations(serviceSpec);
    return new ServiceWrapper(serviceSpec);
  }

  // Handle worker and primary_worker.
  protected void addWorkerComponents(ServiceWrapper serviceWrapper,
      Framework framework)
      throws IOException {
    final Role primaryWorkerRole;
    final Role workerRole;
    if (framework == Framework.TENSORFLOW) {
      primaryWorkerRole = TensorFlowRole.PRIMARY_WORKER;
      workerRole = TensorFlowRole.WORKER;
    } else {
      primaryWorkerRole = PyTorchRole.PRIMARY_WORKER;
      workerRole = PyTorchRole.WORKER;
    }

    addWorkerComponent(serviceWrapper, primaryWorkerRole, framework);

    if (parameters.getNumWorkers() > 1) {
      addWorkerComponent(serviceWrapper, workerRole, framework);
    }
  }

  private void addWorkerComponent(ServiceWrapper serviceWrapper,
      Role role, Framework framework) throws IOException {
    AbstractComponent component = workerFactory.create(framework, role);
    serviceWrapper.addComponent(component);
  }

  protected void handleQuicklinks(Service serviceSpec)
      throws IOException {
    List<Quicklink> quicklinks = parameters.getQuicklinks();
    if (quicklinks != null && !quicklinks.isEmpty()) {
      for (Quicklink ql : quicklinks) {
        // Make sure it is a valid instance name
        String instanceName = ql.getComponentInstanceName();
        boolean found = false;

        for (Component comp : serviceSpec.getComponents()) {
          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
            String possibleInstanceName = comp.getName() + "-" + i;
            if (possibleInstanceName.equals(instanceName)) {
              found = true;
              break;
            }
          }
        }

        if (!found) {
          throw new IOException(
              "Couldn't find a component instance = " + instanceName
                  + " while adding quicklink");
        }

        String link = ql.getProtocol()
            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
            getUserName(), getDNSDomain(yarnConfig), ql.getPort());
        addQuicklink(serviceSpec, ql.getLabel(), link);
      }
    }
  }

  protected static void addQuicklink(Service serviceSpec, String label,
      String link) {
    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
    if (quicklinks == null) {
      quicklinks = new HashMap<>();
      serviceSpec.setQuicklinks(quicklinks);
    }

    if (SubmarineLogs.isVerbose()) {
      LOG.info("Added quicklink, " + label + "=" + link);
    }

    quicklinks.put(label, link);
  }
}
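Concrete service specs are expected to combine the helpers above: build the wrapper, add the worker components for their framework, then wire the quicklinks. A hypothetical minimal implementation, assuming the ServiceSpec interface declares a create() method as used later in YarnServiceJobSubmitter:

    // Hypothetical subclass for illustration; the real TensorFlowServiceSpec
    // and PyTorchServiceSpec may add ps/tensorboard components and more.
    public class MinimalServiceSpec extends AbstractServiceSpec {
      public MinimalServiceSpec(RunJobParameters parameters,
          ClientContext clientContext, FileSystemOperations fsOperations,
          LaunchCommandFactory launchCommandFactory, Localizer localizer) {
        super(parameters, clientContext, fsOperations, launchCommandFactory,
            localizer);
      }

      @Override
      public ServiceWrapper create() throws IOException {
        ServiceWrapper serviceWrapper = createServiceSpecWrapper();
        addWorkerComponents(serviceWrapper, Framework.PYTORCH);
        handleQuicklinks(serviceWrapper.getService());
        return serviceWrapper;
      }
    }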
@@ -37,6 +37,7 @@ import java.io.File;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;

 /**
@@ -195,6 +196,15 @@ public class FileSystemOperations {
     fs.setPermission(destPath, new FsPermission(permission));
   }

+  public static boolean needHdfs(List<String> stringsToCheck) {
+    for (String content : stringsToCheck) {
+      if (content != null && content.contains("hdfs://")) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public static boolean needHdfs(String content) {
     return content != null && content.contains("hdfs://");
   }
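A usage sketch for the new list-based overload, with illustrative values only; note that null entries are tolerated by the null check inside the loop:

    import java.util.Arrays;
    import java.util.List;

    public class NeedHdfsDemo {
      public static void main(String[] args) {
        List<String> stringsToCheck = Arrays.asList(
            "python train.py", "hdfs://input/cifar-10-data", null);
        // Prints true, because one entry contains an hdfs:// URI;
        // the null entry is skipped by the null check in the loop.
        System.out.println(FileSystemOperations.needHdfs(stringsToCheck));
      }
    }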
@@ -16,9 +16,10 @@

 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;

+import org.apache.curator.shaded.com.google.common.collect.ImmutableList;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
@@ -28,6 +29,8 @@ import org.slf4j.LoggerFactory;
 import java.io.File;
 import java.io.IOException;
 import java.io.PrintWriter;
+import java.util.List;
+import java.util.Objects;

 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
 import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
@@ -128,10 +131,22 @@ public class HadoopEnvironmentSetup {
   }

   private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
-    return needHdfs(parameters.getInputPath()) ||
-        needHdfs(parameters.getPSLaunchCmd()) ||
-        needHdfs(parameters.getWorkerLaunchCmd()) ||
-        hadoopEnv;
+    List<String> launchCommands = parameters.getLaunchCommands();
+    if (launchCommands != null) {
+      launchCommands.removeIf(Objects::isNull);
+    }
+
+    ImmutableList.Builder<String> listBuilder = ImmutableList.builder();
+
+    if (launchCommands != null && !launchCommands.isEmpty()) {
+      listBuilder.addAll(launchCommands);
+    }
+    if (parameters.getInputPath() != null) {
+      listBuilder.add(parameters.getInputPath());
+    }
+    List<String> stringsToCheck = listBuilder.build();
+
+    return needHdfs(stringsToCheck) || hadoopEnv;
   }

   private void appendHdfsHome(PrintWriter fw, String hdfsHome) {
@@ -38,7 +38,7 @@ public final class ServiceSpecFileGenerator {
         "instantiated!");
   }

-  static String generateJson(Service service) throws IOException {
+  public static String generateJson(Service service) throws IOException {
     File serviceSpecFile = File.createTempFile(service.getName(), ".json");
     String buffer = jsonSerDeser.toJson(service);
     Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),
@@ -0,0 +1,71 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component.PyTorchWorkerComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;

/**
 * Factory class that helps creating Native Service components.
 */
public class WorkerComponentFactory {
  private final FileSystemOperations fsOperations;
  private final RemoteDirectoryManager remoteDirectoryManager;
  private final RunJobParameters parameters;
  private final LaunchCommandFactory launchCommandFactory;
  private final Configuration yarnConfig;

  WorkerComponentFactory(FileSystemOperations fsOperations,
      RemoteDirectoryManager remoteDirectoryManager,
      RunJobParameters parameters,
      LaunchCommandFactory launchCommandFactory,
      Configuration yarnConfig) {
    this.fsOperations = fsOperations;
    this.remoteDirectoryManager = remoteDirectoryManager;
    this.parameters = parameters;
    this.launchCommandFactory = launchCommandFactory;
    this.yarnConfig = yarnConfig;
  }

  /**
   * Creates either a TensorFlow or a PyTorch Native Service component.
   */
  public AbstractComponent create(Framework framework, Role role) {
    if (framework == Framework.TENSORFLOW) {
      return new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
          (TensorFlowRunJobParameters) parameters, role,
          (TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
    } else if (framework == Framework.PYTORCH) {
      return new PyTorchWorkerComponent(fsOperations, remoteDirectoryManager,
          (PyTorchRunJobParameters) parameters, role,
          (PyTorchLaunchCommandFactory) launchCommandFactory, yarnConfig);
    } else {
      throw new UnsupportedOperationException("Only supported frameworks are: "
          + Framework.getValues());
    }
  }
}
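A usage sketch for the factory above. Since create() downcasts both the parameters and the launch-command factory, the caller must hand in objects matching the requested framework. The collaborators are assumed to be constructed elsewhere, and because the constructor is package-private this helper would live in the same package:

    // Illustrative helper, not part of this commit.
    static AbstractComponent buildPyTorchPrimaryWorker(
        FileSystemOperations fsOperations,
        RemoteDirectoryManager remoteDirectoryManager,
        PyTorchRunJobParameters parameters,
        PyTorchLaunchCommandFactory launchCommandFactory,
        Configuration yarnConfig) {
      WorkerComponentFactory factory = new WorkerComponentFactory(
          fsOperations, remoteDirectoryManager, parameters,
          launchCommandFactory, yarnConfig);
      return factory.create(Framework.PYTORCH, PyTorchRole.PRIMARY_WORKER);
    }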
@@ -20,10 +20,16 @@ import org.apache.hadoop.yarn.client.api.AppAdminClient;
 import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.service.api.records.Service;
 import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
 import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
 import org.apache.hadoop.yarn.submarine.utils.Localizer;
 import org.slf4j.Logger;
@@ -32,6 +38,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;

 import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
+import static org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder.SUPPORTED_FRAMEWORKS_MESSAGE;

 /**
  * Submit a job to cluster.
@@ -51,14 +58,45 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
    * {@inheritDoc}
    */
   @Override
-  public ApplicationId submitJob(RunJobParameters parameters)
+  public ApplicationId submitJob(ParametersHolder paramsHolder)
       throws IOException, YarnException {
+    Framework framework = paramsHolder.getFramework();
+    RunJobParameters parameters =
+        (RunJobParameters) paramsHolder.getParameters();
+
+    if (framework == Framework.TENSORFLOW) {
+      return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
+    } else if (framework == Framework.PYTORCH) {
+      return submitPyTorchJob((PyTorchRunJobParameters) parameters);
+    } else {
+      throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
+    }
+  }
+
+  private ApplicationId submitTensorFlowJob(
+      TensorFlowRunJobParameters parameters) throws IOException, YarnException {
     FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
     HadoopEnvironmentSetup hadoopEnvSetup =
         new HadoopEnvironmentSetup(clientContext, fsOperations);

     Service serviceSpec = createTensorFlowServiceSpec(parameters,
         fsOperations, hadoopEnvSetup);
+    return submitJobInternal(serviceSpec);
+  }
+
+  private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters)
+      throws IOException, YarnException {
+    FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
+    HadoopEnvironmentSetup hadoopEnvSetup =
+        new HadoopEnvironmentSetup(clientContext, fsOperations);
+
+    Service serviceSpec = createPyTorchServiceSpec(parameters,
+        fsOperations, hadoopEnvSetup);
+    return submitJobInternal(serviceSpec);
+  }
+
+  private ApplicationId submitJobInternal(Service serviceSpec)
+      throws IOException, YarnException {
     String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);

     AppAdminClient appAdminClient =
@@ -70,7 +108,7 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
           "Fail to launch application with exit code:" + code);
     }

-    String appStatus=appAdminClient.getStatusString(serviceSpec.getName());
+    String appStatus = appAdminClient.getStatusString(serviceSpec.getName());
     Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);

     // Retry multiple times if applicationId is null
@@ -97,11 +135,12 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
     return appid;
   }

-  private Service createTensorFlowServiceSpec(RunJobParameters parameters,
+  private Service createTensorFlowServiceSpec(
+      TensorFlowRunJobParameters parameters,
       FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
       throws IOException {
-    LaunchCommandFactory launchCommandFactory =
-        new LaunchCommandFactory(hadoopEnvSetup, parameters,
+    TensorFlowLaunchCommandFactory launchCommandFactory =
+        new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters,
             clientContext.getYarnConfig());
     Localizer localizer = new Localizer(fsOperations,
         clientContext.getRemoteDirectoryManager(), parameters);
@@ -113,6 +152,22 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
     return serviceWrapper.getService();
   }

+  private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
+      FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
+      throws IOException {
+    PyTorchLaunchCommandFactory launchCommandFactory =
+        new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters,
+            clientContext.getYarnConfig());
+    Localizer localizer = new Localizer(fsOperations,
+        clientContext.getRemoteDirectoryManager(), parameters);
+    PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
+        parameters, this.clientContext, fsOperations, launchCommandFactory,
+        localizer);
+
+    serviceWrapper = pyTorchServiceSpec.create();
+    return serviceWrapper.getService();
+  }
+
   @VisibleForTesting
   public ServiceWrapper getServiceWrapper() {
     return serviceWrapper;
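The new dispatch can be mocked the same way TestTonyUtils does earlier in this commit; a hypothetical Mockito fragment, with imports as in that test:

    // Hypothetical test fragment, mirroring the mocking style used in
    // TestTonyUtils above: any ParametersHolder yields a fixed app id.
    JobSubmitter submitter = mock(JobSubmitter.class);
    when(submitter.submitJob(any(ParametersHolder.class)))
        .thenReturn(ApplicationId.newInstance(1234L, 1));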
@@ -17,11 +17,9 @@
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;

 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import java.io.IOException;
-import java.util.Objects;

 /**
  * Abstract base class for Launch command implementations for Services.
@@ -32,10 +30,9 @@ public abstract class AbstractLaunchCommand {
   private final LaunchScriptBuilder builder;

   public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters)
-      throws IOException {
-    Objects.requireNonNull(taskType, "TaskType must not be null!");
-    this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
+      Component component, RunJobParameters parameters,
+      String launchCommandPrefix) throws IOException {
+    this.builder = new LaunchScriptBuilder(launchCommandPrefix, hadoopEnvSetup,
         parameters, component);
   }
@@ -16,52 +16,15 @@

 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;

-import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
+import org.apache.hadoop.yarn.submarine.common.api.Role;

 import java.io.IOException;
-import java.util.Objects;

 /**
- * Simple factory to create instances of {@link AbstractLaunchCommand}
- * based on the {@link TaskType}.
- * All dependencies are passed to this factory that could be required
- * by any implementor of {@link AbstractLaunchCommand}.
+ * Interface for creating launch commands.
  */
-public class LaunchCommandFactory {
-  private final HadoopEnvironmentSetup hadoopEnvSetup;
-  private final RunJobParameters parameters;
-  private final Configuration yarnConfig;
-
-  public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
-      RunJobParameters parameters, Configuration yarnConfig) {
-    this.hadoopEnvSetup = hadoopEnvSetup;
-    this.parameters = parameters;
-    this.yarnConfig = yarnConfig;
-  }
-
-  public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
-      Component component) throws IOException {
-    Objects.requireNonNull(taskType, "TaskType must not be null!");
-
-    if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
-      return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
-          component, parameters, yarnConfig);
-
-    } else if (taskType == TaskType.PS) {
-      return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
-          parameters, yarnConfig);
-
-    } else if (taskType == TaskType.TENSORBOARD) {
-      return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
-          parameters);
-    }
-    throw new IllegalStateException("Unknown task type: " + taskType);
-  }
+public interface LaunchCommandFactory {
+  AbstractLaunchCommand createLaunchCommand(Role role, Component component)
+      throws IOException;
 }
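With LaunchCommandFactory reduced to an interface, callers such as AbstractComponent.generateLaunchCommand stay framework-agnostic: they only see the interface, and the concrete TensorFlow or PyTorch factory (both added in this commit) picks the AbstractLaunchCommand implementation. A tiny sketch of that polymorphic call, with the wiring assumed:

    // Illustrative only: the caller never names a concrete factory.
    static AbstractLaunchCommand buildCommand(LaunchCommandFactory factory,
        Role role, Component component) throws IOException {
      return factory.createLaunchCommand(role, component);
    }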
@@ -17,7 +17,7 @@
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;

 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -47,10 +47,11 @@ public class LaunchScriptBuilder {
   private final StringBuilder scriptBuffer;
   private String launchCommand;

-  LaunchScriptBuilder(String namePrefix,
+  LaunchScriptBuilder(String launchScriptPrefix,
       HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
       Component component) throws IOException {
-    this.file = File.createTempFile(namePrefix + "-launch-script", ".sh");
+    this.file = File.createTempFile(launchScriptPrefix +
+        "-launch-script", ".sh");
     this.hadoopEnvSetup = hadoopEnvSetup;
     this.parameters = parameters;
     this.component = component;
PyTorchLaunchCommandFactory.java (new file):

@@ -0,0 +1,61 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;

import java.io.IOException;
import java.util.Objects;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command.PyTorchWorkerLaunchCommand;

/**
 * Simple factory to create instances of {@link AbstractLaunchCommand},
 * based on the {@link Role}. Every dependency that an implementor of
 * {@link AbstractLaunchCommand} could require is passed to this factory.
 */
public class PyTorchLaunchCommandFactory implements LaunchCommandFactory {
  private final HadoopEnvironmentSetup hadoopEnvSetup;
  private final PyTorchRunJobParameters parameters;
  private final Configuration yarnConfig;

  public PyTorchLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
      PyTorchRunJobParameters parameters, Configuration yarnConfig) {
    this.hadoopEnvSetup = hadoopEnvSetup;
    this.parameters = parameters;
    this.yarnConfig = yarnConfig;
  }

  @Override
  public AbstractLaunchCommand createLaunchCommand(Role role,
      Component component) throws IOException {
    Objects.requireNonNull(role, "Role must not be null!");

    if (role == PyTorchRole.WORKER ||
        role == PyTorchRole.PRIMARY_WORKER) {
      return new PyTorchWorkerLaunchCommand(hadoopEnvSetup, role,
          component, parameters, yarnConfig);
    } else {
      throw new IllegalStateException("Unknown task type: " + role);
    }
  }
}
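To make the factory contract concrete, here is a hedged usage sketch. The three constructor dependencies are assumed to be provided by the surrounding YARN service runtime; in practice they are not built by hand, and the helper name below is hypothetical:

// Illustrative helper; all parameters are supplied by the runtime.
static String buildPyTorchWorkerScript(HadoopEnvironmentSetup hadoopEnvSetup,
    PyTorchRunJobParameters parameters, Configuration yarnConfig,
    Component component) throws IOException {
  PyTorchLaunchCommandFactory factory = new PyTorchLaunchCommandFactory(
      hadoopEnvSetup, parameters, yarnConfig);
  // Dispatches to PyTorchWorkerLaunchCommand for WORKER/PRIMARY_WORKER,
  // and throws IllegalStateException for any other role.
  AbstractLaunchCommand command =
      factory.createLaunchCommand(PyTorchRole.PRIMARY_WORKER, component);
  return command.generateLaunchScript();
}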
TensorFlowLaunchCommandFactory.java (new file):

@@ -0,0 +1,70 @@
/* Apache License, Version 2.0 header, identical to the one above. */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;

import java.io.IOException;
import java.util.Objects;

/**
 * Simple factory to create instances of {@link AbstractLaunchCommand},
 * based on the {@link Role}. Every dependency that an implementor of
 * {@link AbstractLaunchCommand} could require is passed to this factory.
 */
public class TensorFlowLaunchCommandFactory implements LaunchCommandFactory {
  private final HadoopEnvironmentSetup hadoopEnvSetup;
  private final TensorFlowRunJobParameters parameters;
  private final Configuration yarnConfig;

  public TensorFlowLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
      TensorFlowRunJobParameters parameters, Configuration yarnConfig) {
    this.hadoopEnvSetup = hadoopEnvSetup;
    this.parameters = parameters;
    this.yarnConfig = yarnConfig;
  }

  @Override
  public AbstractLaunchCommand createLaunchCommand(Role role,
      Component component) throws IOException {
    Objects.requireNonNull(role, "Role must not be null!");

    if (role == TensorFlowRole.WORKER ||
        role == TensorFlowRole.PRIMARY_WORKER) {
      return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, role,
          component, parameters, yarnConfig);
    } else if (role == TensorFlowRole.PS) {
      return new TensorFlowPsLaunchCommand(hadoopEnvSetup, role, component,
          parameters, yarnConfig);
    } else if (role == TensorFlowRole.TENSORBOARD) {
      return new TensorBoardLaunchCommand(hadoopEnvSetup, role, component,
          parameters);
    }
    throw new IllegalStateException("Unknown task type: " + role);
  }
}
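A hedged, test-style sketch of the dispatch table implemented above. The dependencies are assumed to be test fixtures (stubbed parameters must return non-empty launch commands and a checkpoint path, otherwise the command constructors fail fast); the test name is an assumption, not part of the patch:

@Test
public void testRoleDispatch() throws IOException {
  // hadoopEnvSetup, parameters, yarnConfig, component: stubbed fixtures
  TensorFlowLaunchCommandFactory factory = new TensorFlowLaunchCommandFactory(
      hadoopEnvSetup, parameters, yarnConfig);

  assertTrue(factory.createLaunchCommand(TensorFlowRole.WORKER, component)
      instanceof TensorFlowWorkerLaunchCommand);
  assertTrue(factory.createLaunchCommand(TensorFlowRole.PS, component)
      instanceof TensorFlowPsLaunchCommand);
  assertTrue(factory.createLaunchCommand(TensorFlowRole.TENSORBOARD, component)
      instanceof TensorBoardLaunchCommand);
}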
PyTorchServiceSpec.java (new file):

@@ -0,0 +1,68 @@
/* Apache License, Version 2.0 header, identical to the one above. */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;

import java.io.IOException;

import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This class contains all the logic to create an instance
 * of a {@link Service} object for PyTorch.
 * Please note that currently, only single-node (non-distributed)
 * support is implemented for PyTorch.
 */
public class PyTorchServiceSpec extends AbstractServiceSpec {
  private static final Logger LOG =
      LoggerFactory.getLogger(PyTorchServiceSpec.class);
  // This field will be required once distributed PyTorch is supported.
  private final PyTorchRunJobParameters pyTorchParameters;

  public PyTorchServiceSpec(PyTorchRunJobParameters parameters,
      ClientContext clientContext, FileSystemOperations fsOperations,
      PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer) {
    super(parameters, clientContext, fsOperations, launchCommandFactory,
        localizer);
    this.pyTorchParameters = parameters;
  }

  @Override
  public ServiceWrapper create() throws IOException {
    LOG.info("Creating PyTorch service spec");
    ServiceWrapper serviceWrapper = createServiceSpecWrapper();

    if (parameters.getNumWorkers() > 0) {
      addWorkerComponents(serviceWrapper, Framework.PYTORCH);
    }

    // After all components are added, handle quicklinks
    handleQuicklinks(serviceWrapper.getService());

    return serviceWrapper;
  }

}
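A hedged sketch of how the spec above might be consumed. The real wiring lives in the job submitter, which is outside this diff; the helper name below is illustrative and its dependencies are assumed to be constructed upstream:

static Service buildSingleNodePyTorchService(
    PyTorchRunJobParameters parameters, ClientContext clientContext,
    FileSystemOperations fsOperations,
    PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer)
    throws IOException {
  PyTorchServiceSpec spec = new PyTorchServiceSpec(parameters, clientContext,
      fsOperations, launchCommandFactory, localizer);
  // create() builds the wrapper, adds worker components when
  // getNumWorkers() > 0, then attaches user-defined quicklinks.
  ServiceWrapper wrapper = spec.create();
  return wrapper.getService();
}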
PyTorchWorkerLaunchCommand.java (new file):

@@ -0,0 +1,87 @@
/* Apache License, Version 2.0 header, identical to the one above. */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;

import java.io.IOException;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Launch command implementation for PyTorch components.
 */
public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
  private static final Logger LOG =
      LoggerFactory.getLogger(PyTorchWorkerLaunchCommand.class);
  private final Configuration yarnConfig;
  private final boolean distributed;
  private final int numberOfWorkers;
  private final String name;
  private final Role role;
  private final String launchCommand;

  public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
      Role role, Component component,
      PyTorchRunJobParameters parameters,
      Configuration yarnConfig) throws IOException {
    super(hadoopEnvSetup, component, parameters, role.getName());
    this.role = role;
    this.name = parameters.getName();
    this.distributed = parameters.isDistributed();
    this.numberOfWorkers = parameters.getNumWorkers();
    this.yarnConfig = yarnConfig;
    logReceivedParameters();

    this.launchCommand = parameters.getWorkerLaunchCmd();

    if (StringUtils.isEmpty(this.launchCommand)) {
      throw new IllegalArgumentException("LaunchCommand must not be null " +
          "or empty!");
    }
  }

  private void logReceivedParameters() {
    if (this.numberOfWorkers <= 0) {
      LOG.warn("Received number of workers: {}", this.numberOfWorkers);
    }
  }

  @Override
  public String generateLaunchScript() throws IOException {
    LaunchScriptBuilder builder = getBuilder();
    return builder
        .withLaunchCommand(createLaunchCommand())
        .build();
  }

  @Override
  public String createLaunchCommand() {
    if (SubmarineLogs.isVerbose()) {
      LOG.info("PyTorch Worker command =[" + launchCommand + "]");
    }
    return launchCommand + '\n';
  }
}
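The constructor above fails fast when no worker launch command was given. A hedged illustration of that contract; the setter name on PyTorchRunJobParameters is an assumption, and hadoopEnvSetup, component and yarnConfig are assumed stubs:

PyTorchRunJobParameters parameters = new PyTorchRunJobParameters();
parameters.setWorkerLaunchCmd("");  // assumed setter; empty on purpose
try {
  new PyTorchWorkerLaunchCommand(hadoopEnvSetup, PyTorchRole.WORKER,
      component, parameters, yarnConfig);
} catch (IllegalArgumentException expected) {
  // "LaunchCommand must not be null or empty!"
}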
package-info.java for the pytorch.command package (new file):

@@ -0,0 +1,19 @@
/* Apache License, Version 2.0 header, identical to the one above. */
/**
 * This package contains classes to generate PyTorch launch commands.
 */
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
PyTorchWorkerComponent.java (new file):

@@ -0,0 +1,47 @@
/* Apache License, Version 2.0 header, identical to the one above. */

package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;

import java.io.IOException;

/**
 * Component implementation for the worker process of PyTorch.
 */
public class PyTorchWorkerComponent extends AbstractComponent {
  public PyTorchWorkerComponent(FileSystemOperations fsOperations,
      RemoteDirectoryManager remoteDirectoryManager,
      PyTorchRunJobParameters parameters, Role role,
      PyTorchLaunchCommandFactory launchCommandFactory,
      Configuration yarnConfig) {
    super(fsOperations, remoteDirectoryManager, parameters, role,
        yarnConfig, launchCommandFactory);
  }

  @Override
  public Component createComponent() throws IOException {
    return createComponentInternal();
  }
}
package-info.java for the pytorch.component package (new file):

@@ -0,0 +1,20 @@
/* Apache License, Version 2.0 header, identical to the one above. */
/**
 * This package contains classes to generate
 * PyTorch Native Service components.
 */
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
package-info.java for the pytorch package (new file):

@@ -0,0 +1,20 @@
/* Apache License, Version 2.0 header, identical to the one above. */
/**
 * This package contains classes to generate
 * PyTorch-related Native Service runtime artifacts.
 */
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
TensorFlowCommons.java (modified):

@@ -20,7 +20,7 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.hadoop.yarn.submarine.common.Envs;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
 
 import java.util.Map;
@@ -35,10 +35,10 @@ public final class TensorFlowCommons {
   }
 
   public static void addCommonEnvironments(Component component,
-      TaskType taskType) {
+      Role role) {
     Map<String, String> envs = component.getConfiguration().getEnv();
     envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
-    envs.put(Envs.TASK_TYPE_ENV, taskType.name());
+    envs.put(Envs.TASK_TYPE_ENV, role.getName());
   }
 
   public static String getUserName() {
@@ -49,8 +49,8 @@
     return yarnConfig.get("hadoop.registry.dns.domain-name");
   }
 
-  public static String getScriptFileName(TaskType taskType) {
-    return "run-" + taskType.name() + ".sh";
+  public static String getScriptFileName(Role role) {
+    return "run-" + role.getName() + ".sh";
   }
 
   public static String getTFConfigEnv(String componentName, int nWorkers,
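The net effect of switching from TaskType.name() to Role.getName() is that the container environment and the script file name are now derived from the role abstraction shared by both frameworks. A hedged illustration; the exact string returned by getName() depends on the Role implementation:

Role worker = TensorFlowRole.WORKER;
// Environment propagated into the container by addCommonEnvironments():
//   TASK_INDEX -> ServiceApiConstants.COMPONENT_ID (substituted by the AM)
//   TASK_TYPE  -> worker.getName()
String scriptFileName = TensorFlowCommons.getScriptFileName(worker);
// e.g. "run-WORKER.sh", assuming getName() returns "WORKER"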
TensorFlowServiceSpec.java (modified):

@@ -16,39 +16,24 @@
 
 package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
 
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
 import org.apache.hadoop.yarn.service.api.records.Service;
-import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
 import org.apache.hadoop.yarn.submarine.common.ClientContext;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
-import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
-import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
-import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
 import org.apache.hadoop.yarn.submarine.utils.Localizer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
-import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
 
 import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
-import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
-import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
 
 /**
  * This class contains all the logic to create an instance
@@ -56,42 +41,34 @@
  * Worker,PS and Tensorboard components are added to the Service
  * based on the value of the received {@link RunJobParameters}.
  */
-public class TensorFlowServiceSpec implements ServiceSpec {
+public class TensorFlowServiceSpec extends AbstractServiceSpec {
   private static final Logger LOG =
       LoggerFactory.getLogger(TensorFlowServiceSpec.class);
+  private final TensorFlowRunJobParameters tensorFlowParameters;
 
-  private final RemoteDirectoryManager remoteDirectoryManager;
-  private final RunJobParameters parameters;
-  private final Configuration yarnConfig;
-  private final FileSystemOperations fsOperations;
-  private final LaunchCommandFactory launchCommandFactory;
-  private final Localizer localizer;
-
-  public TensorFlowServiceSpec(RunJobParameters parameters,
+  public TensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
       ClientContext clientContext, FileSystemOperations fsOperations,
-      LaunchCommandFactory launchCommandFactory, Localizer localizer) {
-    this.parameters = parameters;
-    this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
-    this.yarnConfig = clientContext.getYarnConfig();
-    this.fsOperations = fsOperations;
-    this.launchCommandFactory = launchCommandFactory;
-    this.localizer = localizer;
+      TensorFlowLaunchCommandFactory launchCommandFactory,
+      Localizer localizer) {
+    super(parameters, clientContext, fsOperations, launchCommandFactory,
+        localizer);
+    this.tensorFlowParameters = parameters;
   }
 
   @Override
   public ServiceWrapper create() throws IOException {
+    LOG.info("Creating TensorFlow service spec");
     ServiceWrapper serviceWrapper = createServiceSpecWrapper();
 
-    if (parameters.getNumWorkers() > 0) {
-      addWorkerComponents(serviceWrapper);
+    if (tensorFlowParameters.getNumWorkers() > 0) {
+      addWorkerComponents(serviceWrapper, Framework.TENSORFLOW);
     }
 
-    if (parameters.getNumPS() > 0) {
+    if (tensorFlowParameters.getNumPS() > 0) {
       addPsComponent(serviceWrapper);
     }
 
-    if (parameters.isTensorboardEnabled()) {
+    if (tensorFlowParameters.isTensorboardEnabled()) {
      createTensorBoardComponent(serviceWrapper);
     }
 
@@ -101,103 +78,23 @@
     return serviceWrapper;
   }
 
-  private ServiceWrapper createServiceSpecWrapper() throws IOException {
-    Service serviceSpec = new Service();
-    serviceSpec.setName(parameters.getName());
-    serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
-    serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
-
-    KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
-        .create(fsOperations, remoteDirectoryManager, parameters);
-    if (kerberosPrincipal != null) {
-      serviceSpec.setKerberosPrincipal(kerberosPrincipal);
-    }
-
-    handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
-    localizer.handleLocalizations(serviceSpec);
-    return new ServiceWrapper(serviceSpec);
-  }
-
   private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
       throws IOException {
     TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
-        remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
+        remoteDirectoryManager, parameters,
+        (TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
     serviceWrapper.addComponent(tbComponent);
 
     addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
         tbComponent.getTensorboardLink());
   }
 
-  private static void addQuicklink(Service serviceSpec, String label,
-      String link) {
-    Map<String, String> quicklinks = serviceSpec.getQuicklinks();
-    if (quicklinks == null) {
-      quicklinks = new HashMap<>();
-      serviceSpec.setQuicklinks(quicklinks);
-    }
-
-    if (SubmarineLogs.isVerbose()) {
-      LOG.info("Added quicklink, " + label + "=" + link);
-    }
-
-    quicklinks.put(label, link);
-  }
-
-  private void handleQuicklinks(Service serviceSpec)
-      throws IOException {
-    List<Quicklink> quicklinks = parameters.getQuicklinks();
-    if (quicklinks != null && !quicklinks.isEmpty()) {
-      for (Quicklink ql : quicklinks) {
-        // Make sure it is a valid instance name
-        String instanceName = ql.getComponentInstanceName();
-        boolean found = false;
-
-        for (Component comp : serviceSpec.getComponents()) {
-          for (int i = 0; i < comp.getNumberOfContainers(); i++) {
-            String possibleInstanceName = comp.getName() + "-" + i;
-            if (possibleInstanceName.equals(instanceName)) {
-              found = true;
-              break;
-            }
-          }
-        }
-
-        if (!found) {
-          throw new IOException(
-              "Couldn't find a component instance = " + instanceName
-                  + " while adding quicklink");
-        }
-
-        String link = ql.getProtocol()
-            + YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
-                getUserName(), getDNSDomain(yarnConfig), ql.getPort());
-        addQuicklink(serviceSpec, ql.getLabel(), link);
-      }
-    }
-  }
-
-  // Handle worker and primary_worker.
-  private void addWorkerComponents(ServiceWrapper serviceWrapper)
-      throws IOException {
-    addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
-
-    if (parameters.getNumWorkers() > 1) {
-      addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
-    }
-  }
-
-  private void addWorkerComponent(ServiceWrapper serviceWrapper,
-      RunJobParameters parameters, TaskType taskType) throws IOException {
-    serviceWrapper.addComponent(
-        new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
-            parameters, taskType, launchCommandFactory, yarnConfig));
-  }
-
   private void addPsComponent(ServiceWrapper serviceWrapper)
       throws IOException {
     serviceWrapper.addComponent(
         new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
-            launchCommandFactory, parameters, yarnConfig));
+            (TensorFlowLaunchCommandFactory) launchCommandFactory,
+            parameters, yarnConfig));
   }
 
 }
TensorBoardLaunchCommand.java (modified):

@@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.slf4j.Logger;
@@ -37,9 +37,9 @@ public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
   private final String checkpointPath;
 
   public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters)
+      Role role, Component component, RunJobParameters parameters)
       throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters);
+    super(hadoopEnvSetup, component, parameters, role.getName());
     Objects.requireNonNull(parameters.getCheckpointPath(),
         "CheckpointPath must not be null as it is part "
             + "of the tensorboard command!");
TensorFlowLaunchCommand.java (modified):

@@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
@@ -28,6 +28,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.Objects;
 
 /**
  * Launch command implementation for
@@ -41,13 +42,16 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
   private final int numberOfWorkers;
   private final int numberOfPS;
   private final String name;
-  private final TaskType taskType;
+  private final Role role;
 
   TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters,
+      Role role, Component component,
+      TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters);
-    this.taskType = taskType;
+    super(hadoopEnvSetup, component, parameters,
+        role != null ? role.getName() : "");
+    Objects.requireNonNull(role, "TensorFlowRole must not be null!");
+    this.role = role;
     this.name = parameters.getName();
     this.distributed = parameters.isDistributed();
     this.numberOfWorkers = parameters.getNumWorkers();
@@ -72,7 +76,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
     // When distributed training is required
     if (distributed) {
       String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
-          taskType.getComponentName(), numberOfWorkers,
+          role.getComponentName(), numberOfWorkers,
           numberOfPS, name,
           TensorFlowCommons.getUserName(),
           TensorFlowCommons.getDNSDomain(yarnConfig));
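For distributed jobs the branch above exports TF_CONFIG, the standard TensorFlow cluster-spec JSON. A hedged sketch of its shape only; hostnames follow the YARN registry DNS pattern, and every concrete value below is made up for illustration (the real string is assembled by TensorFlowCommons.getTFConfigEnv from the component name, worker/PS counts, job name, user and DNS domain):

// Illustrative only, not the exact output of getTFConfigEnv(...):
String tfConfig =
    "{\"cluster\":{"
    + "\"master\":[\"master-0.myjob.alice.example.com:8000\"],"
    + "\"worker\":[\"worker-0.myjob.alice.example.com:8000\"],"
    + "\"ps\":[\"ps-0.myjob.alice.example.com:8000\"]},"
    + "\"task\":{\"type\":\"$TASK_TYPE\",\"index\":$TASK_INDEX},"
    + "\"environment\":\"cloud\"}";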
TensorFlowPsLaunchCommand.java (modified):

@@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
@@ -37,9 +37,10 @@ public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
   private final String launchCommand;
 
   public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
-      TaskType taskType, Component component, RunJobParameters parameters,
+      Role role, Component component,
+      TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+    super(hadoopEnvSetup, role, component, parameters, yarnConfig);
     this.launchCommand = parameters.getPSLaunchCmd();
 
     if (StringUtils.isEmpty(this.launchCommand)) {
TensorFlowWorkerLaunchCommand.java (modified):

@@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.Role;
 import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
 import org.slf4j.Logger;
@@ -37,10 +37,10 @@ public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
   private final String launchCommand;
 
   public TensorFlowWorkerLaunchCommand(
-      HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType,
-      Component component, RunJobParameters parameters,
+      HadoopEnvironmentSetup hadoopEnvSetup, Role role,
+      Component component, TensorFlowRunJobParameters parameters,
       Configuration yarnConfig) throws IOException {
-    super(hadoopEnvSetup, taskType, component, parameters, yarnConfig);
+    super(hadoopEnvSetup, role, component, parameters, yarnConfig);
     this.launchCommand = parameters.getWorkerLaunchCmd();
 
     if (StringUtils.isEmpty(this.launchCommand)) {
TensorBoardComponent.java (modified):

@@ -19,13 +19,14 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.compone
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.service.api.records.Component;
 import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
-import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
-import org.apache.hadoop.yarn.submarine.common.api.TaskType;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
+import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
+import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
 import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
 import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
-import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
+import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,35 +55,38 @@ public class TensorBoardComponent extends AbstractComponent {
   public TensorBoardComponent(FileSystemOperations fsOperations,
       RemoteDirectoryManager remoteDirectoryManager,
       RunJobParameters parameters,
-      LaunchCommandFactory launchCommandFactory,
+      TensorFlowLaunchCommandFactory launchCommandFactory,
       Configuration yarnConfig) {
     super(fsOperations, remoteDirectoryManager, parameters,
-        TaskType.TENSORBOARD, yarnConfig, launchCommandFactory);
+        TensorFlowRole.TENSORBOARD, yarnConfig, launchCommandFactory);
   }
 
   @Override
   public Component createComponent() throws IOException {
-    Objects.requireNonNull(parameters.getTensorboardResource(),
+    TensorFlowRunJobParameters tensorFlowParams =
+        (TensorFlowRunJobParameters) this.parameters;
+
+    Objects.requireNonNull(tensorFlowParams.getTensorboardResource(),
         "TensorBoard resource must not be null!");
 
     Component component = new Component();
-    component.setName(taskType.getComponentName());
+    component.setName(role.getComponentName());
     component.setNumberOfContainers(1L);
     component.setRestartPolicy(RestartPolicyEnum.NEVER);
     component.setResource(convertYarnResourceToServiceResource(
-        parameters.getTensorboardResource()));
+        tensorFlowParams.getTensorboardResource()));
 
-    if (parameters.getTensorboardDockerImage() != null) {
+    if (tensorFlowParams.getTensorboardDockerImage() != null) {
       component.setArtifact(
-          getDockerArtifact(parameters.getTensorboardDockerImage()));
+          getDockerArtifact(tensorFlowParams.getTensorboardDockerImage()));
     }
 
-    addCommonEnvironments(component, taskType);
+    addCommonEnvironments(component, role);
     generateLaunchCommand(component);
 
     tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
         parameters.getName(),
-        taskType.getComponentName() + "-" + 0, getUserName(),
+        role.getComponentName() + "-" + 0, getUserName(),
         getDNSDomain(yarnConfig), DEFAULT_PORT);
     LOG.info("Link to tensorboard:" + tensorboardLink);
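The quicklink assembled above follows the YARN service DNS naming scheme. A hedged, self-contained sketch of the resulting URL; all concrete values, including the port, are assumptions for illustration:

public class TensorboardLinkSketch {
  public static void main(String[] args) {
    String jobName = "my-tf-job";       // parameters.getName()
    String instance = "tensorboard-0";  // role.getComponentName() + "-0"
    String user = "alice";              // getUserName()
    String domain = "example.com";      // hadoop.registry.dns.domain-name
    int port = 6006;                    // DEFAULT_PORT (assumed value)
    String link = "http://" + instance + "." + jobName + "." + user + "."
        + domain + ":" + port;
    // -> http://tensorboard-0.my-tf-job.alice.example.com:6006
    System.out.println("Link to tensorboard: " + link);
  }
}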
Some files were not shown because too many files have changed in this diff.