SUBMARINE-52. [SUBMARINE-14] Generate Service spec + launch script for single-node PyTorch learning job. Contributed by Szilard Nemeth.

This commit is contained in:
Zhankun Tang 2019-05-10 23:40:17 +08:00
parent 64c7f36ab1
commit 36267b6f7c
118 changed files with 4695 additions and 1209 deletions

View File

@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
ARG PYTHON_VERSION=3.6
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
curl \
vim \
ca-certificates \
libjpeg-dev \
libpng-dev \
wget &&\
rm -rf /var/lib/apt/lists/*
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
/opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH
RUN pip install ninja
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/pytorch.git
WORKDIR pytorch
RUN git submodule update --init
RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
pip install -v .
WORKDIR /opt/pytorch
RUN git clone https://github.com/pytorch/vision.git && cd vision && pip install -v .
WORKDIR /
# Install Hadoop
ENV HADOOP_VERSION="3.1.2"
RUN wget https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz
RUN tar zxf hadoop-${HADOOP_VERSION}.tar.gz
RUN ln -s hadoop-${HADOOP_VERSION} hadoop-current
RUN rm hadoop-${HADOOP_VERSION}.tar.gz
ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64
RUN echo "$LOG_TAG Install java8" && \
apt-get update && \
apt-get install -y --no-install-recommends openjdk-8-jdk && \
apt-get clean && rm -rf /var/lib/apt/lists/*
RUN echo "Install python related packages" && \
pip --no-cache-dir install Pillow h5py ipykernel jupyter matplotlib numpy pandas scipy sklearn && \
python -m ipykernel.kernelspec
# Set the locale to fix bash warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)
RUN apt-get update && apt-get install -y --no-install-recommends locales && \
apt-get clean && rm -rf /var/lib/apt/lists/*
RUN locale-gen en_US.UTF-8
WORKDIR /workspace
RUN chmod -R a+w /workspace

View File

@ -0,0 +1,30 @@
#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "Building base images"
set -e
cd base/ubuntu-16.04
docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu-base:0.0.1
echo "Finished building base images"
cd ../../with-cifar10-models/ubuntu-16.04
docker build . -f Dockerfile.gpu.pytorch_latest -t pytorch-latest-gpu:0.0.1

View File

@ -0,0 +1,354 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -*- coding: utf-8 -*-
"""
Training a Classifier
=====================
This is it. You have seen how to define neural networks, compute loss and make
updates to the weights of the network.
Now you might be thinking,
What about data?
----------------
Generally, when you have to deal with image, text, audio or video data,
you can use standard python packages that load data into a numpy array.
Then you can convert this array into a ``torch.*Tensor``.
- For images, packages such as Pillow, OpenCV are useful
- For audio, packages such as scipy and librosa
- For text, either raw Python or Cython based loading, or NLTK and
SpaCy are useful
Specifically for vision, we have created a package called
``torchvision``, that has data loaders for common datasets such as
Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
``torchvision.datasets`` and ``torch.utils.data.DataLoader``.
This provides a huge convenience and avoids writing boilerplate code.
For this tutorial, we will use the CIFAR10 dataset.
It has the classes: airplane, automobile, bird, cat, deer,
dog, frog, horse, ship, truck. The images in CIFAR-10 are of
size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.
.. figure:: /_static/img/cifar10.png
:alt: cifar10
cifar10
Training an image classifier
----------------------------
We will do the following steps in order:
1. Load and normalizing the CIFAR10 training and test datasets using
``torchvision``
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data
1. Loading and normalizing CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Using ``torchvision``, its extremely easy to load CIFAR10.
"""
import torch
import torchvision
import torchvision.transforms as transforms
########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
########################################################################
# Let us show some of the training images, for fun.
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
########################################################################
# 2. Define a Convolutional Neural Network
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Copy the neural network from the Neural Networks section before and modify it to
# take 3-channel images (instead of 1-channel images as it was defined).
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
########################################################################
# 3. Define a Loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's use a Classification Cross-Entropy loss and SGD with momentum.
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize.
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
########################################################################
# 5. Test the network on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We have trained the network for 2 passes over the training dataset.
# But we need to check if the network has learnt anything at all.
#
# We will check this by predicting the class label that the neural network
# outputs, and checking it against the ground-truth. If the prediction is
# correct, we add the sample to the list of correct predictions.
#
# Okay, first step. Let us display an image from the test set to get familiar.
dataiter = iter(testloader)
images, labels = dataiter.next()
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
########################################################################
# Okay, now let us see what the neural network thinks these examples above are:
outputs = net(images)
########################################################################
# The outputs are energies for the 10 classes.
# The higher the energy for a class, the more the network
# thinks that the image is of the particular class.
# So, let's get the index of the highest energy:
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
########################################################################
# The results seem pretty good.
#
# Let us look at how the network performs on the whole dataset.
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
########################################################################
# That looks waaay better than chance, which is 10% accuracy (randomly picking
# a class out of 10 classes).
# Seems like the network learnt something.
#
# Hmmm, what are the classes that performed well, and the classes that did
# not perform well:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
########################################################################
# Okay, so what next?
#
# How do we run these neural networks on the GPU?
#
# Training on GPU
# ----------------
# Just like how you transfer a Tensor onto the GPU, you transfer the neural
# net onto the GPU.
#
# Let's first define our device as the first visible cuda device if we have
# CUDA available:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
########################################################################
# The rest of this section assumes that ``device`` is a CUDA device.
#
# Then these methods will recursively go over all modules and convert their
# parameters and buffers to CUDA tensors:
#
# .. code:: python
#
# net.to(device)
#
#
# Remember that you will have to send the inputs and targets at every step
# to the GPU too:
#
# .. code:: python
#
# inputs, labels = inputs.to(device), labels.to(device)
#
# Why dont I notice MASSIVE speedup compared to CPU? Because your network
# is realllly small.
#
# **Exercise:** Try increasing the width of your network (argument 2 of
# the first ``nn.Conv2d``, and argument 1 of the second ``nn.Conv2d``
# they need to be the same number), see what kind of speedup you get.
#
# **Goals achieved**:
#
# - Understanding PyTorch's Tensor library and neural networks at a high level.
# - Train a small neural network to classify images
#
# Training on multiple GPUs
# -------------------------
# If you want to see even more MASSIVE speedup using all of your GPUs,
# please check out :doc:`data_parallel_tutorial`.
#
# Where do I go next?
# -------------------
#
# - :doc:`Train neural nets to play video games </intermediate/reinforcement_q_learning>`
# - `Train a state-of-the-art ResNet network on imagenet`_
# - `Train a face generator using Generative Adversarial Networks`_
# - `Train a word-level language model using Recurrent LSTM networks`_
# - `More examples`_
# - `More tutorials`_
# - `Discuss PyTorch on the Forums`_
# - `Chat with other users on Slack`_
#
# .. _Train a state-of-the-art ResNet network on imagenet: https://github.com/pytorch/examples/tree/master/imagenet
# .. _Train a face generator using Generative Adversarial Networks: https://github.com/pytorch/examples/tree/master/dcgan
# .. _Train a word-level language model using Recurrent LSTM networks: https://github.com/pytorch/examples/tree/master/word_language_model
# .. _More examples: https://github.com/pytorch/examples
# .. _More tutorials: https://github.com/pytorch/tutorials
# .. _Discuss PyTorch on the Forums: https://discuss.pytorch.org/
# .. _Chat with other users on Slack: https://pytorch.slack.com/messages/beginner/
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%
del dataiter
# %%%%%%INVISIBLE_CODE_BLOCK%%%%%%

View File

@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
FROM pytorch-latest-gpu-base:0.0.1
RUN mkdir -p /test/data
RUN chmod -R 777 /test
ADD cifar10_tutorial.py /test/cifar10_tutorial.py

View File

@ -15,6 +15,7 @@
package org.apache.hadoop.yarn.submarine.client.cli; package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.conf.YarnConfiguration; import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory; import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;

View File

@ -59,4 +59,6 @@ public class CliConstants {
public static final String DISTRIBUTE_KEYTAB = "distribute_keytab"; public static final String DISTRIBUTE_KEYTAB = "distribute_keytab";
public static final String YAML_CONFIG = "f"; public static final String YAML_CONFIG = "f";
public static final String INSECURE_CLUSTER = "insecure"; public static final String INSECURE_CLUSTER = "insecure";
public static final String FRAMEWORK = "framework";
} }

View File

@ -16,7 +16,7 @@ package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException; import org.apache.hadoop.yarn.submarine.common.exception.SubmarineRuntimeException;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.slf4j.Logger; import org.slf4j.Logger;

View File

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli;
/**
* Represents a Submarine command.
*/
public enum Command {
RUN_JOB, SHOW_JOB
}

View File

@ -37,7 +37,7 @@ public class ShowJobCli extends AbstractCli {
private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class); private static final Logger LOG = LoggerFactory.getLogger(ShowJobCli.class);
private Options options; private Options options;
private ShowJobParameters parameters = new ShowJobParameters(); private ParametersHolder parametersHolder;
public ShowJobCli(ClientContext cliContext) { public ShowJobCli(ClientContext cliContext) {
super(cliContext); super(cliContext);
@ -62,9 +62,9 @@ public class ShowJobCli extends AbstractCli {
CommandLine cli; CommandLine cli;
try { try {
cli = parser.parse(options, args); cli = parser.parse(options, args);
ParametersHolder parametersHolder = ParametersHolder parametersHolder = ParametersHolder
.createWithCmdLine(cli); .createWithCmdLine(cli, Command.SHOW_JOB);
parameters.updateParameters(parametersHolder, clientContext); parametersHolder.updateParameters(clientContext);
} catch (ParseException e) { } catch (ParseException e) {
printUsages(); printUsages();
} }
@ -97,7 +97,7 @@ public class ShowJobCli extends AbstractCli {
Map<String, String> jobInfo = null; Map<String, String> jobInfo = null;
try { try {
jobInfo = storage.getJobInfoByName(parameters.getName()); jobInfo = storage.getJobInfoByName(getParameters().getName());
} catch (IOException e) { } catch (IOException e) {
LOG.error("Failed to retrieve job info", e); LOG.error("Failed to retrieve job info", e);
throw e; throw e;
@ -108,7 +108,7 @@ public class ShowJobCli extends AbstractCli {
@VisibleForTesting @VisibleForTesting
public ShowJobParameters getParameters() { public ShowJobParameters getParameters() {
return parameters; return (ShowJobParameters) parametersHolder.getParameters();
} }
@Override @Override

View File

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param;
/**
* Represents the source of configuration.
*/
public enum ConfigType {
YAML, CLI
}

View File

@ -20,8 +20,12 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants; import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.Command;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Roles;
@ -29,15 +33,22 @@ import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Scheduling;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Security;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.TensorBoard;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli.YAML_PARSE_FAILED;
/** /**
* This class acts as a wrapper of {@code CommandLine} values along with * This class acts as a wrapper of {@code CommandLine} values along with
* YAML configuration values. * YAML configuration values.
@ -52,17 +63,110 @@ public final class ParametersHolder {
private static final Logger LOG = private static final Logger LOG =
LoggerFactory.getLogger(ParametersHolder.class); LoggerFactory.getLogger(ParametersHolder.class);
public static final String SUPPORTED_FRAMEWORKS_MESSAGE =
"TensorFlow and PyTorch are the only supported frameworks for now!";
public static final String SUPPORTED_COMMANDS_MESSAGE =
"'Show job' and 'run job' are the only supported commands for now!";
private final CommandLine parsedCommandLine; private final CommandLine parsedCommandLine;
private final Map<String, String> yamlStringConfigs; private final Map<String, String> yamlStringConfigs;
private final Map<String, List<String>> yamlListConfigs; private final Map<String, List<String>> yamlListConfigs;
private final ImmutableSet onlyDefinedWithCliArgs = ImmutableSet.of( private final ConfigType configType;
private Command command;
private final Set onlyDefinedWithCliArgs = ImmutableSet.of(
CliConstants.VERBOSE); CliConstants.VERBOSE);
private final Framework framework;
private final BaseParameters parameters;
private ParametersHolder(CommandLine parsedCommandLine, private ParametersHolder(CommandLine parsedCommandLine,
YamlConfigFile yamlConfig) { YamlConfigFile yamlConfig, ConfigType configType, Command command)
throws ParseException, YarnException {
this.parsedCommandLine = parsedCommandLine; this.parsedCommandLine = parsedCommandLine;
this.yamlStringConfigs = initStringConfigValues(yamlConfig); this.yamlStringConfigs = initStringConfigValues(yamlConfig);
this.yamlListConfigs = initListConfigValues(yamlConfig); this.yamlListConfigs = initListConfigValues(yamlConfig);
this.configType = configType;
this.command = command;
this.framework = determineFrameworkType();
this.ensureOnlyValidSectionsAreDefined(yamlConfig);
this.parameters = createParameters();
}
private BaseParameters createParameters() {
if (command == Command.RUN_JOB) {
if (framework == Framework.TENSORFLOW) {
return new TensorFlowRunJobParameters();
} else if (framework == Framework.PYTORCH) {
return new PyTorchRunJobParameters();
} else {
throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
}
} else if (command == Command.SHOW_JOB) {
return new ShowJobParameters();
} else {
throw new UnsupportedOperationException(SUPPORTED_COMMANDS_MESSAGE);
}
}
private void ensureOnlyValidSectionsAreDefined(YamlConfigFile yamlConfig) {
if (isCommandRunJob() && isFrameworkPyTorch() &&
isPsSectionDefined(yamlConfig)) {
throw new YamlParseException(
"PS section should not be defined when PyTorch " +
"is the selected framework!");
}
if (isCommandRunJob() && isFrameworkPyTorch() &&
isTensorboardSectionDefined(yamlConfig)) {
throw new YamlParseException(
"TensorBoard section should not be defined when PyTorch " +
"is the selected framework!");
}
}
private boolean isCommandRunJob() {
return command == Command.RUN_JOB;
}
private boolean isFrameworkPyTorch() {
return framework == Framework.PYTORCH;
}
private boolean isPsSectionDefined(YamlConfigFile yamlConfig) {
return yamlConfig != null &&
yamlConfig.getRoles() != null &&
yamlConfig.getRoles().getPs() != null;
}
private boolean isTensorboardSectionDefined(YamlConfigFile yamlConfig) {
return yamlConfig != null &&
yamlConfig.getTensorBoard() != null;
}
private Framework determineFrameworkType()
throws ParseException, YarnException {
if (!isCommandRunJob()) {
return null;
}
String frameworkStr = getOptionValue(CliConstants.FRAMEWORK);
if (frameworkStr == null) {
LOG.info("Framework is not defined in config, falling back to " +
"TensorFlow as a default.");
return Framework.TENSORFLOW;
}
Framework framework = Framework.parseByValue(frameworkStr);
if (framework == null) {
if (getConfigType() == ConfigType.CLI) {
throw new ParseException("Failed to parse Framework type! "
+ "Valid values are: " + Framework.getValues());
} else {
throw new YamlParseException(YAML_PARSE_FAILED +
", framework should is defined, but it has an invalid value! " +
"Valid values are: " + Framework.getValues());
}
}
return framework;
} }
/** /**
@ -108,6 +212,8 @@ public final class ParametersHolder {
private void initGenericConfigs(YamlConfigFile yamlConfig, private void initGenericConfigs(YamlConfigFile yamlConfig,
Map<String, String> yamlConfigs) { Map<String, String> yamlConfigs) {
yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName()); yamlConfigs.put(CliConstants.NAME, yamlConfig.getSpec().getName());
yamlConfigs.put(CliConstants.FRAMEWORK,
yamlConfig.getSpec().getFramework());
Configs configs = yamlConfig.getConfigs(); Configs configs = yamlConfig.getConfigs();
yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath()); yamlConfigs.put(CliConstants.INPUT_PATH, configs.getInputPath());
@ -178,13 +284,15 @@ public final class ParametersHolder {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public static ParametersHolder createWithCmdLine(CommandLine cli) { public static ParametersHolder createWithCmdLine(CommandLine cli,
return new ParametersHolder(cli, null); Command command) throws ParseException, YarnException {
return new ParametersHolder(cli, null, ConfigType.CLI, command);
} }
public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli, public static ParametersHolder createWithCmdLineAndYaml(CommandLine cli,
YamlConfigFile yamlConfig) { YamlConfigFile yamlConfig, Command command) throws ParseException,
return new ParametersHolder(cli, yamlConfig); YarnException {
return new ParametersHolder(cli, yamlConfig, ConfigType.YAML, command);
} }
/** /**
@ -193,7 +301,7 @@ public final class ParametersHolder {
* @param option Name of the config. * @param option Name of the config.
* @return The value of the config * @return The value of the config
*/ */
String getOptionValue(String option) throws YarnException { public String getOptionValue(String option) throws YarnException {
ensureConfigIsDefinedOnce(option, true); ensureConfigIsDefinedOnce(option, true);
if (onlyDefinedWithCliArgs.contains(option) || if (onlyDefinedWithCliArgs.contains(option) ||
parsedCommandLine.hasOption(option)) { parsedCommandLine.hasOption(option)) {
@ -208,7 +316,7 @@ public final class ParametersHolder {
* @param option Name of the config. * @param option Name of the config.
* @return The values of the config * @return The values of the config
*/ */
List<String> getOptionValues(String option) throws YarnException { public List<String> getOptionValues(String option) throws YarnException {
ensureConfigIsDefinedOnce(option, false); ensureConfigIsDefinedOnce(option, false);
if (onlyDefinedWithCliArgs.contains(option) || if (onlyDefinedWithCliArgs.contains(option) ||
parsedCommandLine.hasOption(option)) { parsedCommandLine.hasOption(option)) {
@ -285,7 +393,7 @@ public final class ParametersHolder {
* @return true, if the option is found in the CLI args or in the YAML config, * @return true, if the option is found in the CLI args or in the YAML config,
* false otherwise. * false otherwise.
*/ */
boolean hasOption(String option) { public boolean hasOption(String option) {
if (onlyDefinedWithCliArgs.contains(option)) { if (onlyDefinedWithCliArgs.contains(option)) {
boolean value = parsedCommandLine.hasOption(option); boolean value = parsedCommandLine.hasOption(option);
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
@ -312,4 +420,21 @@ public final class ParametersHolder {
"from YAML configuration.", result, option); "from YAML configuration.", result, option);
return result; return result;
} }
public ConfigType getConfigType() {
return configType;
}
public Framework getFramework() {
return framework;
}
public void updateParameters(ClientContext clientContext)
throws ParseException, YarnException, IOException {
parameters.updateParameters(this, clientContext);
}
public BaseParameters getParameters() {
return parameters;
}
} }

View File

@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
import java.io.IOException;
import java.util.List;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import com.google.common.collect.Lists;
/**
* Parameters for PyTorch job.
*/
public class PyTorchRunJobParameters extends RunJobParameters {
private static final String CANNOT_BE_DEFINED_FOR_PYTORCH =
"cannot be defined for PyTorch jobs!";
@Override
public void updateParameters(ParametersHolder parametersHolder,
ClientContext clientContext)
throws ParseException, IOException, YarnException {
checkArguments(parametersHolder);
super.updateParameters(parametersHolder, clientContext);
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
this.workerParameters =
getWorkerParameters(clientContext, parametersHolder, input);
this.distributed = determineIfDistributed(workerParameters.getReplicas());
executePostOperations(clientContext);
}
private void checkArguments(ParametersHolder parametersHolder)
throws YarnException, ParseException {
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.N_PS));
} else if (parametersHolder.getOptionValue(CliConstants.PS_RES) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.PS_RES));
} else if (parametersHolder
.getOptionValue(CliConstants.PS_DOCKER_IMAGE) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.PS_DOCKER_IMAGE));
} else if (parametersHolder
.getOptionValue(CliConstants.PS_LAUNCH_CMD) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.PS_LAUNCH_CMD));
} else if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.TENSORBOARD));
} else if (parametersHolder
.getOptionValue(CliConstants.TENSORBOARD_RESOURCES) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.TENSORBOARD_RESOURCES));
} else if (parametersHolder
.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE) != null) {
throw new ParseException(getParamCannotBeDefinedErrorMessage(
CliConstants.TENSORBOARD_DOCKER_IMAGE));
}
}
private String getParamCannotBeDefinedErrorMessage(String cliName) {
return String.format(
"Parameter '%s' " + CANNOT_BE_DEFINED_FOR_PYTORCH, cliName);
}
@Override
void executePostOperations(ClientContext clientContext) throws IOException {
// Set default job dir / saved model dir, etc.
setDefaultDirs(clientContext);
replacePatternsInParameters(clientContext);
}
private void replacePatternsInParameters(ClientContext clientContext)
throws IOException {
if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
String afterReplace =
CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
clientContext.getRemoteDirectoryManager());
setWorkerLaunchCmd(afterReplace);
}
}
@Override
public List<String> getLaunchCommands() {
return Lists.newArrayList(getWorkerLaunchCmd());
}
/**
* We only support non-distributed PyTorch integration for now.
* @param nWorkers
* @return
*/
private boolean determineIfDistributed(int nWorkers) {
return false;
}
}

View File

@ -1,3 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** /**
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -12,7 +28,7 @@
* limitations under the License. See accompanying LICENSE file. * limitations under the License. See accompanying LICENSE file.
*/ */
package org.apache.hadoop.yarn.submarine.client.cli.param; package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.CaseFormat; import com.google.common.base.CaseFormat;
@ -21,7 +37,14 @@ import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants; import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils; import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.Localization;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.util.resource.ResourceUtils; import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.yaml.snakeyaml.introspector.Property; import org.yaml.snakeyaml.introspector.Property;
import org.yaml.snakeyaml.introspector.PropertyUtils; import org.yaml.snakeyaml.introspector.PropertyUtils;
@ -34,27 +57,15 @@ import java.util.List;
/** /**
* Parameters used to run a job * Parameters used to run a job
*/ */
public class RunJobParameters extends RunParameters { public abstract class RunJobParameters extends RunParameters {
private String input; private String input;
private String checkpointPath; private String checkpointPath;
private int numWorkers;
private int numPS;
private Resource workerResource;
private Resource psResource;
private boolean tensorboardEnabled;
private Resource tensorboardResource;
private String tensorboardDockerImage;
private String workerLaunchCmd;
private String psLaunchCmd;
private List<Quicklink> quicklinks = new ArrayList<>(); private List<Quicklink> quicklinks = new ArrayList<>();
private List<Localization> localizations = new ArrayList<>(); private List<Localization> localizations = new ArrayList<>();
private String psDockerImage = null;
private String workerDockerImage = null;
private boolean waitJobFinish = false; private boolean waitJobFinish = false;
private boolean distributed = false; protected boolean distributed = false;
private boolean securityDisabled = false; private boolean securityDisabled = false;
private String keytab; private String keytab;
@ -62,6 +73,9 @@ public class RunJobParameters extends RunParameters {
private boolean distributeKeytab = false; private boolean distributeKeytab = false;
private List<String> confPairs = new ArrayList<>(); private List<String> confPairs = new ArrayList<>();
RoleParameters workerParameters =
RoleParameters.createEmpty(TensorFlowRole.WORKER);
@Override @Override
public void updateParameters(ParametersHolder parametersHolder, public void updateParameters(ParametersHolder parametersHolder,
ClientContext clientContext) ClientContext clientContext)
@ -70,34 +84,6 @@ public class RunJobParameters extends RunParameters {
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH); String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
String jobDir = parametersHolder.getOptionValue( String jobDir = parametersHolder.getOptionValue(
CliConstants.CHECKPOINT_PATH); CliConstants.CHECKPOINT_PATH);
int nWorkers = 1;
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
nWorkers = Integer.parseInt(
parametersHolder.getOptionValue(CliConstants.N_WORKERS));
// Only check null value.
// Training job shouldn't ignore INPUT_PATH option
// But if nWorkers is 0, INPUT_PATH can be ignored because
// user can only run Tensorboard
if (null == input && 0 != nWorkers) {
throw new ParseException("\"--" + CliConstants.INPUT_PATH +
"\" is absent");
}
}
int nPS = 0;
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
nPS = Integer.parseInt(
parametersHolder.getOptionValue(CliConstants.N_PS));
}
// Check #workers and #ps.
// When distributed training is required
if (nWorkers >= 2 && nPS > 0) {
distributed = true;
} else if (nWorkers <= 1 && nPS > 0) {
throw new ParseException("Only specified one worker but non-zero PS, "
+ "please double check.");
}
if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) { if (parametersHolder.hasOption(CliConstants.INSECURE_CLUSTER)) {
setSecurityDisabled(true); setSecurityDisabled(true);
@ -109,46 +95,6 @@ public class RunJobParameters extends RunParameters {
CliConstants.PRINCIPAL); CliConstants.PRINCIPAL);
CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal); CliUtils.doLoginIfSecure(kerberosKeytab, kerberosPrincipal);
workerResource = null;
if (nWorkers > 0) {
String workerResourceStr = parametersHolder.getOptionValue(
CliConstants.WORKER_RES);
if (workerResourceStr == null) {
throw new ParseException(
"--" + CliConstants.WORKER_RES + " is absent.");
}
workerResource = ResourceUtils.createResourceFromString(
workerResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
Resource psResource = null;
if (nPS > 0) {
String psResourceStr = parametersHolder.getOptionValue(
CliConstants.PS_RES);
if (psResourceStr == null) {
throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
}
psResource = ResourceUtils.createResourceFromString(psResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
boolean tensorboard = false;
if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
tensorboard = true;
String tensorboardResourceStr = parametersHolder.getOptionValue(
CliConstants.TENSORBOARD_RESOURCES);
if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
}
tensorboardResource = ResourceUtils.createResourceFromString(
tensorboardResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
tensorboardDockerImage = parametersHolder.getOptionValue(
CliConstants.TENSORBOARD_DOCKER_IMAGE);
this.setTensorboardResource(tensorboardResource);
}
if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) { if (parametersHolder.hasOption(CliConstants.WAIT_JOB_FINISH)) {
this.waitJobFinish = true; this.waitJobFinish = true;
} }
@ -164,16 +110,6 @@ public class RunJobParameters extends RunParameters {
} }
} }
psDockerImage = parametersHolder.getOptionValue(
CliConstants.PS_DOCKER_IMAGE);
workerDockerImage = parametersHolder.getOptionValue(
CliConstants.WORKER_DOCKER_IMAGE);
String workerLaunchCmd = parametersHolder.getOptionValue(
CliConstants.WORKER_LAUNCH_CMD);
String psLaunchCommand = parametersHolder.getOptionValue(
CliConstants.PS_LAUNCH_CMD);
// Localizations // Localizations
List<String> localizationsStr = parametersHolder.getOptionValues( List<String> localizationsStr = parametersHolder.getOptionValues(
CliConstants.LOCALIZATION); CliConstants.LOCALIZATION);
@ -191,10 +127,6 @@ public class RunJobParameters extends RunParameters {
.getOptionValues(CliConstants.ARG_CONF); .getOptionValues(CliConstants.ARG_CONF);
this.setInputPath(input).setCheckpointPath(jobDir) this.setInputPath(input).setCheckpointPath(jobDir)
.setNumPS(nPS).setNumWorkers(nWorkers)
.setPSLaunchCmd(psLaunchCommand).setWorkerLaunchCmd(workerLaunchCmd)
.setPsResource(psResource)
.setTensorboardEnabled(tensorboard)
.setKeytab(kerberosKeytab) .setKeytab(kerberosKeytab)
.setPrincipal(kerberosPrincipal) .setPrincipal(kerberosPrincipal)
.setDistributeKeytab(distributeKerberosKeytab) .setDistributeKeytab(distributeKerberosKeytab)
@ -203,6 +135,39 @@ public class RunJobParameters extends RunParameters {
super.updateParameters(parametersHolder, clientContext); super.updateParameters(parametersHolder, clientContext);
} }
abstract void executePostOperations(ClientContext clientContext)
throws IOException;
void setDefaultDirs(ClientContext clientContext) throws IOException {
// Create directories if needed
String jobDir = getCheckpointPath();
if (jobDir == null) {
jobDir = getJobDir(clientContext);
setCheckpointPath(jobDir);
}
if (getNumWorkers() > 0) {
String savedModelDir = getSavedModelPath();
if (savedModelDir == null) {
savedModelDir = jobDir;
setSavedModelPath(savedModelDir);
}
}
}
private String getJobDir(ClientContext clientContext) throws IOException {
RemoteDirectoryManager rdm = clientContext.getRemoteDirectoryManager();
if (getNumWorkers() > 0) {
return rdm.getJobCheckpointDir(getName(), true).toString();
} else {
// when #workers == 0, it means we only launch TB. In that case,
// point job dir to root dir so all job's metrics will be shown.
return rdm.getUserRootFolder().toString();
}
}
public abstract List<String> getLaunchCommands();
public String getInputPath() { public String getInputPath() {
return input; return input;
} }
@ -221,110 +186,10 @@ public class RunJobParameters extends RunParameters {
return this; return this;
} }
public int getNumWorkers() {
return numWorkers;
}
public RunJobParameters setNumWorkers(int numWorkers) {
this.numWorkers = numWorkers;
return this;
}
public int getNumPS() {
return numPS;
}
public RunJobParameters setNumPS(int numPS) {
this.numPS = numPS;
return this;
}
public Resource getWorkerResource() {
return workerResource;
}
public RunJobParameters setWorkerResource(Resource workerResource) {
this.workerResource = workerResource;
return this;
}
public Resource getPsResource() {
return psResource;
}
public RunJobParameters setPsResource(Resource psResource) {
this.psResource = psResource;
return this;
}
public boolean isTensorboardEnabled() {
return tensorboardEnabled;
}
public RunJobParameters setTensorboardEnabled(boolean tensorboardEnabled) {
this.tensorboardEnabled = tensorboardEnabled;
return this;
}
public String getWorkerLaunchCmd() {
return workerLaunchCmd;
}
public RunJobParameters setWorkerLaunchCmd(String workerLaunchCmd) {
this.workerLaunchCmd = workerLaunchCmd;
return this;
}
public String getPSLaunchCmd() {
return psLaunchCmd;
}
public RunJobParameters setPSLaunchCmd(String psLaunchCmd) {
this.psLaunchCmd = psLaunchCmd;
return this;
}
public boolean isWaitJobFinish() { public boolean isWaitJobFinish() {
return waitJobFinish; return waitJobFinish;
} }
public String getPsDockerImage() {
return psDockerImage;
}
public void setPsDockerImage(String psDockerImage) {
this.psDockerImage = psDockerImage;
}
public String getWorkerDockerImage() {
return workerDockerImage;
}
public void setWorkerDockerImage(String workerDockerImage) {
this.workerDockerImage = workerDockerImage;
}
public boolean isDistributed() {
return distributed;
}
public Resource getTensorboardResource() {
return tensorboardResource;
}
public void setTensorboardResource(Resource tensorboardResource) {
this.tensorboardResource = tensorboardResource;
}
public String getTensorboardDockerImage() {
return tensorboardDockerImage;
}
public void setTensorboardDockerImage(String tensorboardDockerImage) {
this.tensorboardDockerImage = tensorboardDockerImage;
}
public List<Quicklink> getQuicklinks() { public List<Quicklink> getQuicklinks() {
return quicklinks; return quicklinks;
} }
@ -382,6 +247,90 @@ public class RunJobParameters extends RunParameters {
this.distributed = distributed; this.distributed = distributed;
} }
RoleParameters getWorkerParameters(ClientContext clientContext,
ParametersHolder parametersHolder, String input)
throws ParseException, YarnException, IOException {
int nWorkers = getNumberOfWorkers(parametersHolder, input);
Resource workerResource =
determineWorkerResource(parametersHolder, nWorkers, clientContext);
String workerDockerImage =
parametersHolder.getOptionValue(CliConstants.WORKER_DOCKER_IMAGE);
String workerLaunchCmd =
parametersHolder.getOptionValue(CliConstants.WORKER_LAUNCH_CMD);
return new RoleParameters(TensorFlowRole.WORKER, nWorkers,
workerLaunchCmd, workerDockerImage, workerResource);
}
private Resource determineWorkerResource(ParametersHolder parametersHolder,
int nWorkers, ClientContext clientContext)
throws ParseException, YarnException, IOException {
if (nWorkers > 0) {
String workerResourceStr =
parametersHolder.getOptionValue(CliConstants.WORKER_RES);
if (workerResourceStr == null) {
throw new ParseException(
"--" + CliConstants.WORKER_RES + " is absent.");
}
return ResourceUtils.createResourceFromString(workerResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
return null;
}
private int getNumberOfWorkers(ParametersHolder parametersHolder,
String input) throws ParseException, YarnException {
int nWorkers = 1;
if (parametersHolder.getOptionValue(CliConstants.N_WORKERS) != null) {
nWorkers = Integer
.parseInt(parametersHolder.getOptionValue(CliConstants.N_WORKERS));
// Only check null value.
// Training job shouldn't ignore INPUT_PATH option
// But if nWorkers is 0, INPUT_PATH can be ignored because
// user can only run Tensorboard
if (null == input && 0 != nWorkers) {
throw new ParseException(
"\"--" + CliConstants.INPUT_PATH + "\" is absent");
}
}
return nWorkers;
}
public String getWorkerLaunchCmd() {
return workerParameters.getLaunchCommand();
}
public void setWorkerLaunchCmd(String launchCmd) {
workerParameters.setLaunchCommand(launchCmd);
}
public int getNumWorkers() {
return workerParameters.getReplicas();
}
public void setNumWorkers(int numWorkers) {
workerParameters.setReplicas(numWorkers);
}
public Resource getWorkerResource() {
return workerParameters.getResource();
}
public void setWorkerResource(Resource resource) {
workerParameters.setResource(resource);
}
public String getWorkerDockerImage() {
return workerParameters.getDockerImage();
}
public void setWorkerDockerImage(String image) {
workerParameters.setDockerImage(image);
}
public boolean isDistributed() {
return distributed;
}
@VisibleForTesting @VisibleForTesting
public static class UnderscoreConverterPropertyUtils extends PropertyUtils { public static class UnderscoreConverterPropertyUtils extends PropertyUtils {
@Override @Override

View File

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;
import com.google.common.collect.Lists;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RoleParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import java.io.IOException;
import java.util.List;
/**
* Parameters for TensorFlow job.
*/
public class TensorFlowRunJobParameters extends RunJobParameters {
private boolean tensorboardEnabled;
private RoleParameters psParameters =
RoleParameters.createEmpty(TensorFlowRole.PS);
private RoleParameters tensorBoardParameters =
RoleParameters.createEmpty(TensorFlowRole.TENSORBOARD);
@Override
public void updateParameters(ParametersHolder parametersHolder,
ClientContext clientContext)
throws ParseException, IOException, YarnException {
super.updateParameters(parametersHolder, clientContext);
String input = parametersHolder.getOptionValue(CliConstants.INPUT_PATH);
this.workerParameters =
getWorkerParameters(clientContext, parametersHolder, input);
this.psParameters = getPSParameters(clientContext, parametersHolder);
this.distributed = determineIfDistributed(workerParameters.getReplicas(),
psParameters.getReplicas());
if (parametersHolder.hasOption(CliConstants.TENSORBOARD)) {
this.tensorboardEnabled = true;
this.tensorBoardParameters =
getTensorBoardParameters(parametersHolder, clientContext);
}
executePostOperations(clientContext);
}
@Override
void executePostOperations(ClientContext clientContext) throws IOException {
// Set default job dir / saved model dir, etc.
setDefaultDirs(clientContext);
replacePatternsInParameters(clientContext);
}
private void replacePatternsInParameters(ClientContext clientContext)
throws IOException {
if (StringUtils.isNotEmpty(getPSLaunchCmd())) {
String afterReplace = CliUtils.replacePatternsInLaunchCommand(
getPSLaunchCmd(), this, clientContext.getRemoteDirectoryManager());
setPSLaunchCmd(afterReplace);
}
if (StringUtils.isNotEmpty(getWorkerLaunchCmd())) {
String afterReplace =
CliUtils.replacePatternsInLaunchCommand(getWorkerLaunchCmd(), this,
clientContext.getRemoteDirectoryManager());
setWorkerLaunchCmd(afterReplace);
}
}
@Override
public List<String> getLaunchCommands() {
return Lists.newArrayList(getWorkerLaunchCmd(), getPSLaunchCmd());
}
private boolean determineIfDistributed(int nWorkers, int nPS)
throws ParseException {
// Check #workers and #ps.
// When distributed training is required
if (nWorkers >= 2 && nPS > 0) {
return true;
} else if (nWorkers <= 1 && nPS > 0) {
throw new ParseException("Only specified one worker but non-zero PS, "
+ "please double check.");
}
return false;
}
private RoleParameters getPSParameters(ClientContext clientContext,
ParametersHolder parametersHolder)
throws YarnException, IOException, ParseException {
int nPS = getNumberOfPS(parametersHolder);
Resource psResource =
determinePSResource(parametersHolder, nPS, clientContext);
String psDockerImage =
parametersHolder.getOptionValue(CliConstants.PS_DOCKER_IMAGE);
String psLaunchCommand =
parametersHolder.getOptionValue(CliConstants.PS_LAUNCH_CMD);
return new RoleParameters(TensorFlowRole.PS, nPS, psLaunchCommand,
psDockerImage, psResource);
}
private Resource determinePSResource(ParametersHolder parametersHolder,
int nPS, ClientContext clientContext)
throws ParseException, YarnException, IOException {
if (nPS > 0) {
String psResourceStr =
parametersHolder.getOptionValue(CliConstants.PS_RES);
if (psResourceStr == null) {
throw new ParseException("--" + CliConstants.PS_RES + " is absent.");
}
return ResourceUtils.createResourceFromString(psResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
}
return null;
}
private int getNumberOfPS(ParametersHolder parametersHolder)
throws YarnException {
int nPS = 0;
if (parametersHolder.getOptionValue(CliConstants.N_PS) != null) {
nPS =
Integer.parseInt(parametersHolder.getOptionValue(CliConstants.N_PS));
}
return nPS;
}
private RoleParameters getTensorBoardParameters(
ParametersHolder parametersHolder, ClientContext clientContext)
throws YarnException, IOException {
String tensorboardResourceStr =
parametersHolder.getOptionValue(CliConstants.TENSORBOARD_RESOURCES);
if (tensorboardResourceStr == null || tensorboardResourceStr.isEmpty()) {
tensorboardResourceStr = CliConstants.TENSORBOARD_DEFAULT_RESOURCES;
}
Resource tensorboardResource =
ResourceUtils.createResourceFromString(tensorboardResourceStr,
clientContext.getOrCreateYarnClient().getResourceTypeInfo());
String tensorboardDockerImage =
parametersHolder.getOptionValue(CliConstants.TENSORBOARD_DOCKER_IMAGE);
return new RoleParameters(TensorFlowRole.TENSORBOARD, 1, null,
tensorboardDockerImage, tensorboardResource);
}
public int getNumPS() {
return psParameters.getReplicas();
}
public void setNumPS(int numPS) {
psParameters.setReplicas(numPS);
}
public Resource getPsResource() {
return psParameters.getResource();
}
public void setPsResource(Resource resource) {
psParameters.setResource(resource);
}
public String getPsDockerImage() {
return psParameters.getDockerImage();
}
public void setPsDockerImage(String image) {
psParameters.setDockerImage(image);
}
public String getPSLaunchCmd() {
return psParameters.getLaunchCommand();
}
public void setPSLaunchCmd(String launchCmd) {
psParameters.setLaunchCommand(launchCmd);
}
public boolean isTensorboardEnabled() {
return tensorboardEnabled;
}
public Resource getTensorboardResource() {
return tensorBoardParameters.getResource();
}
public void setTensorboardResource(Resource resource) {
tensorBoardParameters.setResource(resource);
}
public String getTensorboardDockerImage() {
return tensorBoardParameters.getDockerImage();
}
public void setTensorboardDockerImage(String image) {
tensorBoardParameters.setDockerImage(image);
}
}

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes that hold run job parameters for
* TensorFlow / PyTorch jobs.
*/
package org.apache.hadoop.yarn.submarine.client.cli.param.runjob;

View File

@ -22,6 +22,7 @@ package org.apache.hadoop.yarn.submarine.client.cli.param.yaml;
public class Spec { public class Spec {
private String name; private String name;
private String jobType; private String jobType;
private String framework;
public String getJobType() { public String getJobType() {
return jobType; return jobType;
@ -38,4 +39,12 @@ public class Spec {
public void setName(String name) { public void setName(String name) {
this.name = name; this.name = name;
} }
public String getFramework() {
return framework;
}
public void setFramework(String framework) {
this.framework = framework;
}
} }

View File

@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;
/**
* Represents the type of Machine learning framework to work with.
*/
public enum Framework {
TENSORFLOW(Constants.TENSORFLOW_NAME), PYTORCH(Constants.PYTORCH_NAME);
private String value;
Framework(String value) {
this.value = value;
}
public String getValue() {
return value;
}
public static Framework parseByValue(String value) {
for (Framework fw : Framework.values()) {
if (fw.value.equalsIgnoreCase(value)) {
return fw;
}
}
return null;
}
public static String getValues() {
List<String> values = Lists.newArrayList(Framework.values()).stream()
.map(fw -> fw.value).collect(Collectors.toList());
return String.join(",", values);
}
private static class Constants {
static final String TENSORFLOW_NAME = "tensorflow";
static final String PYTORCH_NAME = "pytorch";
}
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.submarine.common.api.Role;
/**
* This class encapsulates data related to a particular Role.
* Some examples: TF Worker process, TF PS process or a PyTorch worker process.
*/
public class RoleParameters {
private final Role role;
private int replicas;
private String launchCommand;
private String dockerImage;
private Resource resource;
public RoleParameters(Role role, int replicas,
String launchCommand, String dockerImage, Resource resource) {
this.role = role;
this.replicas = replicas;
this.launchCommand = launchCommand;
this.dockerImage = dockerImage;
this.resource = resource;
}
public static RoleParameters createEmpty(Role role) {
return new RoleParameters(role, 0, null, null, null);
}
public Role getRole() {
return role;
}
public int getReplicas() {
return replicas;
}
public String getLaunchCommand() {
return launchCommand;
}
public void setLaunchCommand(String launchCommand) {
this.launchCommand = launchCommand;
}
public String getDockerImage() {
return dockerImage;
}
public void setDockerImage(String dockerImage) {
this.dockerImage = dockerImage;
}
public Resource getResource() {
return resource;
}
public void setResource(Resource resource) {
this.resource = resource;
}
public void setReplicas(int replicas) {
this.replicas = replicas;
}
}

View File

@ -1,3 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** /**
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -12,7 +28,7 @@
* limitations under the License. See accompanying LICENSE file. * limitations under the License. See accompanying LICENSE file.
*/ */
package org.apache.hadoop.yarn.submarine.client.cli; package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLine;
@ -23,9 +39,13 @@ import org.apache.commons.cli.ParseException;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.AbstractCli;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.CliUtils;
import org.apache.hadoop.yarn.submarine.client.cli.Command;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder; import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
@ -44,17 +64,25 @@ import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
/**
* This purpose of this class is to handle / parse CLI arguments related to
* the run job Submarine command.
*/
public class RunJobCli extends AbstractCli { public class RunJobCli extends AbstractCli {
private static final Logger LOG = private static final Logger LOG =
LoggerFactory.getLogger(RunJobCli.class); LoggerFactory.getLogger(RunJobCli.class);
private static final String YAML_PARSE_FAILED = "Failed to parse " + private static final String CAN_BE_USED_WITH_TF_PYTORCH =
"Can be used with TensorFlow or PyTorch frameworks.";
private static final String CAN_BE_USED_WITH_TF_ONLY =
"Can only be used with TensorFlow framework.";
public static final String YAML_PARSE_FAILED = "Failed to parse " +
"YAML config"; "YAML config";
private Options options;
private RunJobParameters parameters = new RunJobParameters();
private Options options;
private JobSubmitter jobSubmitter; private JobSubmitter jobSubmitter;
private JobMonitor jobMonitor; private JobMonitor jobMonitor;
private ParametersHolder parametersHolder;
public RunJobCli(ClientContext cliContext) { public RunJobCli(ClientContext cliContext) {
this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(), this(cliContext, cliContext.getRuntimeFactory().getJobSubmitterInstance(),
@ -62,7 +90,7 @@ public class RunJobCli extends AbstractCli {
} }
@VisibleForTesting @VisibleForTesting
RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter, public RunJobCli(ClientContext cliContext, JobSubmitter jobSubmitter,
JobMonitor jobMonitor) { JobMonitor jobMonitor) {
super(cliContext); super(cliContext);
this.options = generateOptions(); this.options = generateOptions();
@ -78,6 +106,10 @@ public class RunJobCli extends AbstractCli {
Options options = new Options(); Options options = new Options();
options.addOption(CliConstants.YAML_CONFIG, true, options.addOption(CliConstants.YAML_CONFIG, true,
"Config file (in YAML format)"); "Config file (in YAML format)");
options.addOption(CliConstants.FRAMEWORK, true,
String.format("Framework to use. Valid values are: %s! " +
"The default framework is Tensorflow.",
Framework.getValues()));
options.addOption(CliConstants.NAME, true, "Name of the job"); options.addOption(CliConstants.NAME, true, "Name of the job");
options.addOption(CliConstants.INPUT_PATH, true, options.addOption(CliConstants.INPUT_PATH, true,
"Input of the job, could be local or other FS directory"); "Input of the job, could be local or other FS directory");
@ -88,48 +120,22 @@ public class RunJobCli extends AbstractCli {
options.addOption(CliConstants.SAVED_MODEL_PATH, true, options.addOption(CliConstants.SAVED_MODEL_PATH, true,
"Model exported path (savedmodel) of the job, which is needed when " "Model exported path (savedmodel) of the job, which is needed when "
+ "exported model is not placed under ${checkpoint_path}" + "exported model is not placed under ${checkpoint_path}"
+ "could be local or other FS directory. This will be used to serve."); + "could be local or other FS directory. " +
options.addOption(CliConstants.N_WORKERS, true, "This will be used to serve.");
"Number of worker tasks of the job, by default it's 1");
options.addOption(CliConstants.N_PS, true,
"Number of PS tasks of the job, by default it's 0");
options.addOption(CliConstants.WORKER_RES, true,
"Resource of each worker, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2");
options.addOption(CliConstants.PS_RES, true,
"Resource of each PS, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2");
options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag"); options.addOption(CliConstants.DOCKER_IMAGE, true, "Docker image name/tag");
options.addOption(CliConstants.QUEUE, true, options.addOption(CliConstants.QUEUE, true,
"Name of queue to run the job, by default it uses default queue"); "Name of queue to run the job, by default it uses default queue");
options.addOption(CliConstants.TENSORBOARD, false,
"Should we run TensorBoard" addWorkerOptions(options);
+ " for this job? By default it's disabled"); addPSOptions(options);
options.addOption(CliConstants.TENSORBOARD_RESOURCES, true, addTensorboardOptions(options);
"Specify resources of Tensorboard, by default it is "
+ CliConstants.TENSORBOARD_DEFAULT_RESOURCES);
options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
"Specify Tensorboard docker image. when this is not "
+ "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+ " as default.");
options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the worker");
options.addOption(CliConstants.PS_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the PS");
options.addOption(CliConstants.ENV, true, options.addOption(CliConstants.ENV, true,
"Common environment variable of worker/ps"); "Common environment variable of worker/ps");
options.addOption(CliConstants.VERBOSE, false, options.addOption(CliConstants.VERBOSE, false,
"Print verbose log for troubleshooting"); "Print verbose log for troubleshooting");
options.addOption(CliConstants.WAIT_JOB_FINISH, false, options.addOption(CliConstants.WAIT_JOB_FINISH, false,
"Specified when user want to wait the job finish"); "Specified when user want to wait the job finish");
options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
"Specify docker image for PS, when this is not specified, PS uses --"
+ CliConstants.DOCKER_IMAGE + " as default.");
options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
"Specify docker image for WORKER, when this is not specified, WORKER "
+ "uses --" + CliConstants.DOCKER_IMAGE + " as default.");
options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN" options.addOption(CliConstants.QUICKLINK, true, "Specify quicklink so YARN"
+ "web UI shows link to given role instance and port. When " + "web UI shows link to given role instance and port. When "
+ "--tensorboard is specified, quicklink to tensorboard instance will " + "--tensorboard is specified, quicklink to tensorboard instance will "
@ -172,63 +178,97 @@ public class RunJobCli extends AbstractCli {
return options; return options;
} }
private void replacePatternsInParameters() throws IOException { private void addWorkerOptions(Options options) {
if (parameters.getPSLaunchCmd() != null && !parameters.getPSLaunchCmd() options.addOption(CliConstants.N_WORKERS, true,
.isEmpty()) { "Number of worker tasks of the job, by default it's 1." +
String afterReplace = CliUtils.replacePatternsInLaunchCommand( CAN_BE_USED_WITH_TF_PYTORCH);
parameters.getPSLaunchCmd(), parameters, options.addOption(CliConstants.WORKER_DOCKER_IMAGE, true,
clientContext.getRemoteDirectoryManager()); "Specify docker image for WORKER, when this is not specified, WORKER "
parameters.setPSLaunchCmd(afterReplace); + "uses --" + CliConstants.DOCKER_IMAGE + " as default." +
} CAN_BE_USED_WITH_TF_PYTORCH);
options.addOption(CliConstants.WORKER_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the worker" +
CAN_BE_USED_WITH_TF_PYTORCH);
options.addOption(CliConstants.WORKER_RES, true,
"Resource of each worker, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
CAN_BE_USED_WITH_TF_PYTORCH);
}
if (parameters.getWorkerLaunchCmd() != null && !parameters private void addPSOptions(Options options) {
.getWorkerLaunchCmd().isEmpty()) { options.addOption(CliConstants.N_PS, true,
String afterReplace = CliUtils.replacePatternsInLaunchCommand( "Number of PS tasks of the job, by default it's 0. " +
parameters.getWorkerLaunchCmd(), parameters, CAN_BE_USED_WITH_TF_ONLY);
clientContext.getRemoteDirectoryManager()); options.addOption(CliConstants.PS_DOCKER_IMAGE, true,
parameters.setWorkerLaunchCmd(afterReplace); "Specify docker image for PS, when this is not specified, PS uses --"
} + CliConstants.DOCKER_IMAGE + " as default." +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.PS_LAUNCH_CMD, true,
"Commandline of worker, arguments will be "
+ "directly used to launch the PS" +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.PS_RES, true,
"Resource of each PS, for example "
+ "memory-mb=2048,vcores=2,yarn.io/gpu=2" +
CAN_BE_USED_WITH_TF_ONLY);
}
private void addTensorboardOptions(Options options) {
options.addOption(CliConstants.TENSORBOARD, false,
"Should we run TensorBoard"
+ " for this job? By default it's disabled." +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.TENSORBOARD_RESOURCES, true,
"Specify resources of Tensorboard, by default it is "
+ CliConstants.TENSORBOARD_DEFAULT_RESOURCES + "." +
CAN_BE_USED_WITH_TF_ONLY);
options.addOption(CliConstants.TENSORBOARD_DOCKER_IMAGE, true,
"Specify Tensorboard docker image. when this is not "
+ "specified, Tensorboard " + "uses --" + CliConstants.DOCKER_IMAGE
+ " as default." +
CAN_BE_USED_WITH_TF_ONLY);
} }
private void parseCommandLineAndGetRunJobParameters(String[] args) private void parseCommandLineAndGetRunJobParameters(String[] args)
throws ParseException, IOException, YarnException { throws ParseException, IOException, YarnException {
try { try {
// Do parsing
GnuParser parser = new GnuParser(); GnuParser parser = new GnuParser();
CommandLine cli = parser.parse(options, args); CommandLine cli = parser.parse(options, args);
ParametersHolder parametersHolder = createParametersHolder(cli); parametersHolder = createParametersHolder(cli);
parameters.updateParameters(parametersHolder, clientContext); parametersHolder.updateParameters(clientContext);
} catch (ParseException e) { } catch (ParseException e) {
LOG.error("Exception in parse: {}", e.getMessage()); LOG.error("Exception in parse: {}", e.getMessage());
printUsages(); printUsages();
throw e; throw e;
} }
// Set default job dir / saved model dir, etc.
setDefaultDirs();
// replace patterns
replacePatternsInParameters();
} }
private ParametersHolder createParametersHolder(CommandLine cli) { private ParametersHolder createParametersHolder(CommandLine cli)
throws ParseException, YarnException {
String yamlConfigFile = String yamlConfigFile =
cli.getOptionValue(CliConstants.YAML_CONFIG); cli.getOptionValue(CliConstants.YAML_CONFIG);
if (yamlConfigFile != null) { if (yamlConfigFile != null) {
YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile); YamlConfigFile yamlConfig = readYamlConfigFile(yamlConfigFile);
if (yamlConfig == null) { checkYamlConfig(yamlConfigFile, yamlConfig);
throw new YamlParseException(String.format(
YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
} else if (yamlConfig.getConfigs() == null) {
throw new YamlParseException(String.format(YAML_PARSE_FAILED +
", config section should be defined, but it cannot be found in " +
"YAML file '%s'!", yamlConfigFile));
}
LOG.info("Using YAML configuration!"); LOG.info("Using YAML configuration!");
return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig); return ParametersHolder.createWithCmdLineAndYaml(cli, yamlConfig,
Command.RUN_JOB);
} else { } else {
LOG.info("Using CLI configuration!"); LOG.info("Using CLI configuration!");
return ParametersHolder.createWithCmdLine(cli); return ParametersHolder.createWithCmdLine(cli, Command.RUN_JOB);
}
}
private void checkYamlConfig(String yamlConfigFile,
YamlConfigFile yamlConfig) {
if (yamlConfig == null) {
throw new YamlParseException(String.format(
YAML_PARSE_FAILED + ", file is empty: %s", yamlConfigFile));
} else if (yamlConfig.getConfigs() == null) {
throw new YamlParseException(String.format(YAML_PARSE_FAILED +
", config section should be defined, but it cannot be found in " +
"YAML file '%s'!", yamlConfigFile));
} }
} }
@ -256,34 +296,9 @@ public class RunJobCli extends AbstractCli {
e); e);
} }
private void setDefaultDirs() throws IOException { private void storeJobInformation(RunJobParameters parameters,
// Create directories if needed ApplicationId applicationId, String[] args) throws IOException {
String jobDir = parameters.getCheckpointPath(); String jobName = parameters.getName();
if (null == jobDir) {
if (parameters.getNumWorkers() > 0) {
jobDir = clientContext.getRemoteDirectoryManager().getJobCheckpointDir(
parameters.getName(), true).toString();
} else {
// when #workers == 0, it means we only launch TB. In that case,
// point job dir to root dir so all job's metrics will be shown.
jobDir = clientContext.getRemoteDirectoryManager().getUserRootFolder()
.toString();
}
parameters.setCheckpointPath(jobDir);
}
if (parameters.getNumWorkers() > 0) {
// Only do this when #worker > 0
String savedModelDir = parameters.getSavedModelPath();
if (null == savedModelDir) {
savedModelDir = jobDir;
parameters.setSavedModelPath(savedModelDir);
}
}
}
private void storeJobInformation(String jobName, ApplicationId applicationId,
String[] args) throws IOException {
Map<String, String> jobInfo = new HashMap<>(); Map<String, String> jobInfo = new HashMap<>();
jobInfo.put(StorageKeyConstants.JOB_NAME, jobName); jobInfo.put(StorageKeyConstants.JOB_NAME, jobName);
jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString()); jobInfo.put(StorageKeyConstants.APPLICATION_ID, applicationId.toString());
@ -316,8 +331,10 @@ public class RunJobCli extends AbstractCli {
} }
parseCommandLineAndGetRunJobParameters(args); parseCommandLineAndGetRunJobParameters(args);
ApplicationId applicationId = this.jobSubmitter.submitJob(parameters); ApplicationId applicationId = jobSubmitter.submitJob(parametersHolder);
storeJobInformation(parameters.getName(), applicationId, args); RunJobParameters parameters =
(RunJobParameters) parametersHolder.getParameters();
storeJobInformation(parameters, applicationId, args);
if (parameters.isWaitJobFinish()) { if (parameters.isWaitJobFinish()) {
this.jobMonitor.waitTrainingFinal(parameters.getName()); this.jobMonitor.waitTrainingFinal(parameters.getName());
} }
@ -332,6 +349,6 @@ public class RunJobCli extends AbstractCli {
@VisibleForTesting @VisibleForTesting
public RunJobParameters getRunJobParameters() { public RunJobParameters getRunJobParameters() {
return parameters; return (RunJobParameters) parametersHolder.getParameters();
} }
} }

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes that are related to the run job command.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;

View File

@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. See accompanying LICENSE file.
*/
package org.apache.hadoop.yarn.submarine.common.api;
/**
* Enum to represent a PyTorch Role.
*/
public enum PyTorchRole implements Role {
PRIMARY_WORKER("master"),
WORKER("worker");
private String compName;
PyTorchRole(String compName) {
this.compName = compName;
}
public String getComponentName() {
return compName;
}
@Override
public String getName() {
return name();
}
}

View File

@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.common.api;
/**
* Interface for a Role.
*/
public interface Role {
String getComponentName();
String getName();
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.common.api;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;
/**
* Represents the type of Runtime.
*/
public enum Runtime {
TONY(Constants.TONY), YARN_SERVICE(Constants.YARN_SERVICE);
private String value;
Runtime(String value) {
this.value = value;
}
public String getValue() {
return value;
}
public static Runtime parseByValue(String value) {
for (Runtime rt : Runtime.values()) {
if (rt.value.equalsIgnoreCase(value)) {
return rt;
}
}
return null;
}
public static String getValues() {
List<String> values = Lists.newArrayList(Runtime.values()).stream()
.map(rt -> rt.value).collect(Collectors.toList());
return String.join(",", values);
}
public static class Constants {
public static final String TONY = "tony";
public static final String YARN_SERVICE = "yarnservice";
}
}

View File

@ -14,7 +14,10 @@
package org.apache.hadoop.yarn.submarine.common.api; package org.apache.hadoop.yarn.submarine.common.api;
public enum TaskType { /**
* Enum to represent a TensorFlow Role.
*/
public enum TensorFlowRole implements Role {
PRIMARY_WORKER("master"), PRIMARY_WORKER("master"),
WORKER("worker"), WORKER("worker"),
PS("ps"), PS("ps"),
@ -22,11 +25,17 @@ public enum TaskType {
private String compName; private String compName;
TaskType(String compName) { TensorFlowRole(String compName) {
this.compName = compName; this.compName = compName;
} }
@Override
public String getComponentName() { public String getComponentName() {
return compName; return compName;
} }
@Override
public String getName() {
return name();
}
} }

View File

@ -16,21 +16,21 @@ package org.apache.hadoop.yarn.submarine.runtimes.common;
import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import java.io.IOException; import java.io.IOException;
/** /**
* Submit job to cluster master * Submit job to cluster master.
*/ */
public interface JobSubmitter { public interface JobSubmitter {
/** /**
* Submit job to cluster * Submit a job to cluster.
* @param parameters run job parameters * @param parameters run job parameters
* @return applicatioId when successfully submitted * @return applicationId when successfully submitted
* @throws YarnException for issues while contacting YARN daemons * @throws YarnException for issues while contacting YARN daemons
* @throws IOException for other issues. * @throws IOException for other issues.
*/ */
ApplicationId submitJob(RunJobParameters parameters) ApplicationId submitJob(ParametersHolder parameters)
throws IOException, YarnException; throws IOException, YarnException;
} }

View File

@ -40,6 +40,10 @@ More details, please refer to
```$xslt ```$xslt
usage: job run usage: job run
-framework <arg> Framework to use.
Valid values are: tensorflow, pytorch.
The default framework is Tensorflow.
-checkpoint_path <arg> Training output directory of the job, could -checkpoint_path <arg> Training output directory of the job, could
be local or other FS directory. This be local or other FS directory. This
typically includes checkpoint files and typically includes checkpoint files and
@ -130,6 +134,7 @@ For submarine internal configuration, please create a `submarine.xml` which shou
#### Commandline #### Commandline
``` ```
yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \ yarn jar path-to/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar job run \
--framework tensorflow \
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \ --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
--env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \ --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current --name tf-job-001 \
--docker_image <your-docker-image> \ --docker_image <your-docker-image> \
@ -163,6 +168,7 @@ See below screenshot:
``` ```
yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \ yarn jar hadoop-yarn-applications-submarine-<version>.jar job run \
--name tf-job-001 --docker_image <your-docker-image> \ --name tf-job-001 --docker_image <your-docker-image> \
--framework tensorflow \
--input_path hdfs://default/dataset/cifar-10-data \ --input_path hdfs://default/dataset/cifar-10-data \
--checkpoint_path hdfs://default/tmp/cifar-10-jobdir \ --checkpoint_path hdfs://default/tmp/cifar-10-jobdir \
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \ --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
@ -208,6 +214,7 @@ After that, you can run ```tensorboard --logdir=<checkpoint-path>``` to view Ten
yarn app -destroy tensorboard-service; \ yarn app -destroy tensorboard-service; \
yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \ yarn jar /tmp/hadoop-yarn-applications-submarine-3.2.0-SNAPSHOT.jar \
job run --name tensorboard-service --verbose --docker_image <your-docker-image> \ job run --name tensorboard-service --verbose --docker_image <your-docker-image> \
--framework tensorflow \
--env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \ --env DOCKER_JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/jre/ \
--env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \ --env DOCKER_HADOOP_HDFS_HOME=/hadoop-current \
--num_workers 0 --tensorboard --num_workers 0 --tensorboard

View File

@ -1,226 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TestRunJobCliParsing {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Test
public void testPrintHelp() {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
JobMonitor mockJobMonitor = mock(JobMonitor.class);
RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
mockJobMonitor);
runJobCli.printUsages();
}
static MockClientContext getMockClientContext()
throws IOException, YarnException {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn(
ApplicationId.newInstance(1234L, 1));
JobMonitor mockJobMonitor = mock(JobMonitor.class);
SubmarineStorage storage = mock(SubmarineStorage.class);
RuntimeFactory rtFactory = mock(RuntimeFactory.class);
when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
when(rtFactory.getSubmarineStorage()).thenReturn(storage);
mockClientContext.setRuntimeFactory(rtFactory);
return mockClientContext;
}
@Test
public void testBasicRunJobForDistributedTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
"--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
"--principal", "user/_HOST@domain.com", "--distribute_keytab",
"--verbose" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(jobRunParameters.getNumPS(), 2);
assertEquals(jobRunParameters.getPSLaunchCmd(), "python run-ps.py");
assertEquals(Resources.createResource(4096, 4),
jobRunParameters.getPsResource());
assertEquals(jobRunParameters.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(2048, 2),
jobRunParameters.getWorkerResource());
assertEquals(jobRunParameters.getDockerImageName(),
"tf-docker:1.1.0");
assertEquals(jobRunParameters.getKeytab(),
"/keytab/path");
assertEquals(jobRunParameters.getPrincipal(),
"user/_HOST@domain.com");
Assert.assertTrue(jobRunParameters.isDistributeKeytab());
Assert.assertTrue(SubmarineLogs.isVerbose());
}
@Test
public void testBasicRunJobForSingleNodeTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(jobRunParameters.getNumWorkers(), 1);
assertEquals(jobRunParameters.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(4096, 2),
jobRunParameters.getWorkerResource());
Assert.assertTrue(SubmarineLogs.isVerbose());
Assert.assertTrue(jobRunParameters.isWaitJobFinish());
}
@Test
public void testNoInputPathOptionSpecified() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\" is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
/**
* when only run tensorboard, input_path is not needed
* */
@Test
public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
boolean success = true;
try {
runJobCli.run(
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
success = false;
}
Assert.assertTrue(success);
}
@Test
public void testJobWithoutName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage =
"--" + CliConstants.NAME + " is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testLaunchCommandPatternReplace() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py --input=%input_path% " +
"--model_dir=%checkpoint_path% " +
"--export_dir=%saved_model_path%/savedmodel",
"--worker_resources", "memory=2048,vcores=2", "--ps_resources",
"memory=4096,vcores=4", "--tensorboard", "true", "--ps_launch_cmd",
"python run-ps.py --input=%input_path% " +
"--model_dir=%checkpoint_path%/model",
"--verbose" });
assertEquals(
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+ "--export_dir=hdfs://output/savedmodel",
runJobCli.getRunJobParameters().getWorkerLaunchCmd());
assertEquals(
"python run-ps.py --input=hdfs://input --model_dir=hdfs://output/model",
runJobCli.getRunJobParameters().getPSLaunchCmd());
}
}

View File

@ -17,7 +17,7 @@
package org.apache.hadoop.yarn.submarine.client.cli; package org.apache.hadoop.yarn.submarine.client.cli;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters.UnderscoreConverterPropertyUtils; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters.UnderscoreConverterPropertyUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlConfigFile;
import org.yaml.snakeyaml.Yaml; import org.yaml.snakeyaml.Yaml;
import org.yaml.snakeyaml.constructor.Constructor; import org.yaml.snakeyaml.constructor.Constructor;
@ -33,13 +33,13 @@ public final class YamlConfigTestUtils {
private YamlConfigTestUtils() {} private YamlConfigTestUtils() {}
static void deleteFile(File file) { public static void deleteFile(File file) {
if (file != null) { if (file != null) {
file.delete(); file.delete();
} }
} }
static YamlConfigFile readYamlConfigFile(String filename) { public static YamlConfigFile readYamlConfigFile(String filename) {
Constructor constructor = new Constructor(YamlConfigFile.class); Constructor constructor = new Constructor(YamlConfigFile.class);
constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils()); constructor.setPropertyUtils(new UnderscoreConverterPropertyUtils());
Yaml yaml = new Yaml(constructor); Yaml yaml = new Yaml(constructor);
@ -49,7 +49,8 @@ public final class YamlConfigTestUtils {
return yaml.loadAs(inputStream, YamlConfigFile.class); return yaml.loadAs(inputStream, YamlConfigFile.class);
} }
static File createTempFileWithContents(String filename) throws IOException { public static File createTempFileWithContents(String filename)
throws IOException {
InputStream inputStream = YamlConfigTestUtils.class InputStream inputStream = YamlConfigTestUtils.class
.getClassLoader() .getClassLoader()
.getResourceAsStream(filename); .getResourceAsStream(filename);
@ -58,7 +59,7 @@ public final class YamlConfigTestUtils {
return targetFile; return targetFile;
} }
static File createEmptyTempFile() throws IOException { public static File createEmptyTempFile() throws IOException {
return File.createTempFile("test", ".yaml"); return File.createTempFile("test", ".yaml");
} }

View File

@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import org.apache.commons.cli.MissingArgumentException;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.common.SubmarineStorage;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* This class contains some test methods to test common functionality
* (including TF / PyTorch) of the run job Submarine command.
*/
public class TestRunJobCliParsingCommon {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
public static MockClientContext getMockClientContext()
throws IOException, YarnException {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
when(mockJobSubmitter.submitJob(any(ParametersHolder.class)))
.thenReturn(ApplicationId.newInstance(1235L, 1));
JobMonitor mockJobMonitor = mock(JobMonitor.class);
SubmarineStorage storage = mock(SubmarineStorage.class);
RuntimeFactory rtFactory = mock(RuntimeFactory.class);
when(rtFactory.getJobSubmitterInstance()).thenReturn(mockJobSubmitter);
when(rtFactory.getJobMonitorInstance()).thenReturn(mockJobMonitor);
when(rtFactory.getSubmarineStorage()).thenReturn(storage);
mockClientContext.setRuntimeFactory(rtFactory);
return mockClientContext;
}
@Test
public void testAbsentFrameworkFallsBackToTensorFlow() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[]{"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
assertTrue("Default Framework should be TensorFlow!",
runJobParameters instanceof TensorFlowRunJobParameters);
}
@Test
public void testEmptyFrameworkOption() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(MissingArgumentException.class);
expectedException.expectMessage("Missing argument for option: framework");
runJobCli.run(
new String[]{"--framework", "--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
}
@Test
public void testInvalidFrameworkOption() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("Failed to parse Framework type");
runJobCli.run(
new String[]{"--framework", "bla", "--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
}
}

View File

@ -0,0 +1,252 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* This class contains some test methods to test common YAML parsing
* functionality (including TF / PyTorch) of the run job Submarine command.
*/
public class TestRunJobCliParsingCommonYaml {
private static final String DIR_NAME = "runjob-common-yaml";
private static final String TF_DIR = "runjob-pytorch-yaml";
private File yamlConfig;
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@After
public void after() {
YamlConfigTestUtils.deleteFile(yamlConfig);
}
@BeforeClass
public static void configureResourceTypes() {
List<ResourceTypeInfo> resTypes = new ArrayList<>(
ResourceUtils.getResourcesTypeInfo());
resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
ResourceUtils.reinitializeResources(resTypes);
}
@Rule
public ExpectedException exception = ExpectedException.none();
@Test
public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
TF_DIR + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"-f", yamlConfig.getAbsolutePath() };
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
TF_DIR + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--quicklink", "AAA=http://master-0:8321",
"--quicklink", "BBB=http://worker-0:1234",
"-f", yamlConfig.getAbsolutePath()};
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testFalseValuesForBooleanFields() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/test-false-values.yaml");
runJobCli.run(
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertFalse(jobRunParameters.isDistributeKeytab());
assertFalse(jobRunParameters.isWaitJobFinish());
assertFalse(tensorFlowParams.isTensorboardEnabled());
}
@Test
public void testWrongIndentation() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-indentation.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testWrongFilename() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", "not-existing", "--verbose"});
}
@Test
public void testEmptyFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testNotExistingFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("file does not exist");
runJobCli.run(
new String[]{"-f", "blabla", "--verbose"});
}
@Test
public void testWrongPropertyName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-property-name.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingConfigsSection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/missing-configs.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("config section should be defined, " +
"but it cannot be found");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingSectionsShouldParsed() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/some-sections-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testAbsentFramework() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/missing-framework.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testEmptyFramework() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/empty-framework.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testInvalidFramework() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/invalid-framework.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("framework should is defined, " +
"but it has an invalid value");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
}

View File

@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob;
import com.google.common.collect.Lists;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobMonitor;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
/**
* This class contains some test methods to test common CLI parsing
* functionality (including TF / PyTorch) of the run job Submarine command.
*/
@RunWith(Parameterized.class)
public class TestRunJobCliParsingParameterized {
private final Framework framework;
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Parameterized.Parameters
public static Collection<Object[]> data() {
Collection<Object[]> params = new ArrayList<>();
params.add(new Object[]{Framework.TENSORFLOW });
params.add(new Object[]{Framework.PYTORCH });
return params;
}
public TestRunJobCliParsingParameterized(Framework framework) {
this.framework = framework;
}
private String getFrameworkName() {
return framework.getValue();
}
@Test
public void testPrintHelp() {
MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
JobMonitor mockJobMonitor = mock(JobMonitor.class);
RunJobCli runJobCli = new RunJobCli(mockClientContext, mockJobSubmitter,
mockJobMonitor);
runJobCli.printUsages();
}
@Test
public void testNoInputPathOptionSpecified() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH + "\"" +
" is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--framework", getFrameworkName(),
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--verbose",
"--wait_job_finish"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testJobWithoutName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage =
"--" + CliConstants.NAME + " is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--framework", getFrameworkName(),
"--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--verbose"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testLaunchCommandPatternReplace() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
List<String> parameters = Lists.newArrayList("--framework",
getFrameworkName(),
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd", "python run-job.py --input=%input_path% " +
"--model_dir=%checkpoint_path% " +
"--export_dir=%saved_model_path%/savedmodel",
"--worker_resources", "memory=2048,vcores=2");
if (framework == Framework.TENSORFLOW) {
parameters.addAll(Lists.newArrayList(
"--ps_resources", "memory=4096,vcores=4",
"--ps_launch_cmd", "python run-ps.py --input=%input_path% " +
"--model_dir=%checkpoint_path%/model",
"--verbose"));
}
runJobCli.run(parameters.toArray(new String[0]));
RunJobParameters runJobParameters = checkExpectedFrameworkParams(runJobCli);
if (framework == Framework.TENSORFLOW) {
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) runJobParameters;
assertEquals(
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+ "--export_dir=hdfs://output/savedmodel",
tensorFlowParams.getWorkerLaunchCmd());
assertEquals(
"python run-ps.py --input=hdfs://input " +
"--model_dir=hdfs://output/model",
tensorFlowParams.getPSLaunchCmd());
} else if (framework == Framework.PYTORCH) {
PyTorchRunJobParameters pyTorchParameters =
(PyTorchRunJobParameters) runJobParameters;
assertEquals(
"python run-job.py --input=hdfs://input --model_dir=hdfs://output "
+ "--export_dir=hdfs://output/savedmodel",
pyTorchParameters.getWorkerLaunchCmd());
}
}
private RunJobParameters checkExpectedFrameworkParams(RunJobCli runJobCli) {
RunJobParameters runJobParameters = runJobCli.getRunJobParameters();
if (framework == Framework.TENSORFLOW) {
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
runJobParameters instanceof TensorFlowRunJobParameters);
} else if (framework == Framework.PYTORCH) {
assertTrue(RunJobParameters.class + " must be an instance of " +
PyTorchRunJobParameters.class,
runJobParameters instanceof PyTorchRunJobParameters);
}
return runJobParameters;
}
}

View File

@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* Test class that verifies the correctness of PyTorch
* CLI configuration parsing.
*/
public class TestRunJobCliParsingPyTorch {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testBasicRunJobForSingleNodeTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--verbose",
"--wait_job_finish" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class +
" must be an instance of " +
PyTorchRunJobParameters.class,
jobRunParameters instanceof PyTorchRunJobParameters);
PyTorchRunJobParameters pyTorchParams =
(PyTorchRunJobParameters) jobRunParameters;
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(pyTorchParams.getNumWorkers(), 1);
assertEquals(pyTorchParams.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(4096, 2),
pyTorchParams.getWorkerResource());
assertTrue(SubmarineLogs.isVerbose());
assertTrue(jobRunParameters.isWaitJobFinish());
}
@Test
public void testNumPSCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path","hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--num_ps", "2" });
}
@Test
public void testPSResourcesCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=2048M,vcores=2" });
}
@Test
public void testPSDockerImageCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_docker_image", "psDockerImage" });
}
@Test
public void testPSLaunchCommandCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_launch_cmd", "psLaunchCommand" });
}
@Test
public void testTensorboardCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--tensorboard" });
}
@Test
public void testTensorboardResourcesCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--tensorboard_resources", "memory=2048M,vcores=2" });
}
@Test
public void testTensorboardDockerImageCannotBeDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
expectedException.expect(ParseException.class);
expectedException.expectMessage("cannot be defined for PyTorch jobs");
runJobCli.run(
new String[] {"--framework", "pytorch",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3",
"--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--tensorboard_docker_image", "TBDockerImage" });
}
}

View File

@ -0,0 +1,225 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob.pytorch;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
/**
* Test class that verifies the correctness of PyTorch
* YAML configuration parsing.
*/
public class TestRunJobCliParsingPyTorchYaml {
private static final String OVERRIDDEN_PREFIX = "overridden_";
private static final String DIR_NAME = "runjob-pytorch-yaml";
private File yamlConfig;
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@After
public void after() {
YamlConfigTestUtils.deleteFile(yamlConfig);
}
@BeforeClass
public static void configureResourceTypes() {
List<ResourceTypeInfo> resTypes = new ArrayList<>(
ResourceUtils.getResourcesTypeInfo());
resTypes.add(ResourceTypeInfo.newInstance(ResourceInformation.GPU_URI, ""));
ResourceUtils.reinitializeResources(resTypes);
}
@Rule
public ExpectedException exception = ExpectedException.none();
private void verifyBasicConfigValues(RunJobParameters jobRunParameters) {
verifyBasicConfigValues(jobRunParameters,
ImmutableList.of("env1=env1Value", "env2=env2Value"));
}
private void verifyBasicConfigValues(RunJobParameters jobRunParameters,
List<String> expectedEnvs) {
assertEquals("testInputPath", jobRunParameters.getInputPath());
assertEquals("testCheckpointPath", jobRunParameters.getCheckpointPath());
assertEquals("testDockerImage", jobRunParameters.getDockerImageName());
assertNotNull(jobRunParameters.getLocalizations());
assertEquals(2, jobRunParameters.getLocalizations().size());
assertNotNull(jobRunParameters.getQuicklinks());
assertEquals(2, jobRunParameters.getQuicklinks().size());
assertTrue(SubmarineLogs.isVerbose());
assertTrue(jobRunParameters.isWaitJobFinish());
for (String env : expectedEnvs) {
assertTrue(String.format(
"%s should be in env list of jobRunParameters!", env),
jobRunParameters.getEnvars().contains(env));
}
}
private void verifyWorkerValues(RunJobParameters jobRunParameters,
String prefix) {
assertTrue(RunJobParameters.class + " must be an instance of " +
PyTorchRunJobParameters.class,
jobRunParameters instanceof PyTorchRunJobParameters);
PyTorchRunJobParameters tensorFlowParams =
(PyTorchRunJobParameters) jobRunParameters;
assertEquals(3, tensorFlowParams.getNumWorkers());
assertEquals(prefix + "testLaunchCmdWorker",
tensorFlowParams.getWorkerLaunchCmd());
assertEquals(prefix + "testDockerImageWorker",
tensorFlowParams.getWorkerDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "2").build()),
tensorFlowParams.getWorkerResource());
}
private void verifySecurityValues(RunJobParameters jobRunParameters) {
assertEquals("keytabPath", jobRunParameters.getKeytab());
assertEquals("testPrincipal", jobRunParameters.getPrincipal());
assertTrue(jobRunParameters.isDistributeKeytab());
}
@Test
public void testValidYamlParsing() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config.yaml");
runJobCli.run(
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyWorkerValues(jobRunParameters, "");
verifySecurityValues(jobRunParameters);
}
@Test
public void testRoleOverrides() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config-with-overrides.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyWorkerValues(jobRunParameters, OVERRIDDEN_PREFIX);
verifySecurityValues(jobRunParameters);
}
@Test
public void testMissingPrincipalUnderSecuritySection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/security-principal-is-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters);
verifyWorkerValues(jobRunParameters, "");
//Verify security values
assertEquals("keytabPath", jobRunParameters.getKeytab());
assertNull("Principal should be null!", jobRunParameters.getPrincipal());
assertTrue(jobRunParameters.isDistributeKeytab());
}
@Test
public void testMissingEnvs() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/envs-are-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters, ImmutableList.of());
verifyWorkerValues(jobRunParameters, "");
verifySecurityValues(jobRunParameters);
}
@Test
public void testInvalidConfigPsSectionIsDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("PS section should not be defined " +
"when PyTorch is the selected framework");
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/invalid-config-ps-section.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testInvalidConfigTensorboardSectionIsDefined() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("TensorBoard section should not be defined " +
"when PyTorch is the selected framework");
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/invalid-config-tensorboard-section.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
}

View File

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.yarn.submarine.client.cli.CliConstants;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.Resources;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
* Test class that verifies the correctness of TensorFlow
* CLI configuration parsing.
*/
public class TestRunJobCliParsingTensorFlow {
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testNoInputPathOptionSpecified() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
String expectedErrorMessage = "\"--" + CliConstants.INPUT_PATH +
"\" is absent";
String actualMessage = "";
try {
runJobCli.run(
new String[]{"--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--checkpoint_path", "hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish"});
} catch (ParseException e) {
actualMessage = e.getMessage();
e.printStackTrace();
}
assertEquals(expectedErrorMessage, actualMessage);
}
@Test
public void testBasicRunJobForDistributedTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input",
"--checkpoint_path", "hdfs://output",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4G,vcores=4", "--tensorboard", "true",
"--ps_launch_cmd", "python run-ps.py", "--keytab", "/keytab/path",
"--principal", "user/_HOST@domain.com", "--distribute_keytab",
"--verbose" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class +
" must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(tensorFlowParams.getNumPS(), 2);
assertEquals(tensorFlowParams.getPSLaunchCmd(), "python run-ps.py");
assertEquals(Resources.createResource(4096, 4),
tensorFlowParams.getPsResource());
assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(2048, 2),
tensorFlowParams.getWorkerResource());
assertEquals(jobRunParameters.getDockerImageName(),
"tf-docker:1.1.0");
assertEquals(jobRunParameters.getKeytab(),
"/keytab/path");
assertEquals(jobRunParameters.getPrincipal(),
"user/_HOST@domain.com");
assertTrue(jobRunParameters.isDistributeKeytab());
assertTrue(SubmarineLogs.isVerbose());
}
@Test
public void testBasicRunJobForSingleNodeTraining() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
assertFalse(SubmarineLogs.isVerbose());
runJobCli.run(
new String[] { "--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--checkpoint_path",
"hdfs://output",
"--num_workers", "1", "--worker_launch_cmd", "python run-job.py",
"--worker_resources", "memory=4g,vcores=2", "--tensorboard",
"true", "--verbose", "--wait_job_finish" });
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class +
" must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(jobRunParameters.getInputPath(), "hdfs://input");
assertEquals(jobRunParameters.getCheckpointPath(), "hdfs://output");
assertEquals(tensorFlowParams.getNumWorkers(), 1);
assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
"python run-job.py");
assertEquals(Resources.createResource(4096, 2),
tensorFlowParams.getWorkerResource());
assertTrue(SubmarineLogs.isVerbose());
assertTrue(jobRunParameters.isWaitJobFinish());
}
/**
* when only run tensorboard, input_path is not needed
* */
@Test
public void testNoInputPathOptionButOnlyRunTensorboard() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
boolean success = true;
try {
runJobCli.run(
new String[]{"--framework", "tensorflow",
"--name", "my-job", "--docker_image", "tf-docker:1.1.0",
"--num_workers", "0", "--tensorboard", "--verbose",
"--tensorboard_resources", "memory=2G,vcores=2",
"--tensorboard_docker_image", "tb_docker_image:001"});
} catch (ParseException e) {
success = false;
}
assertTrue(success);
}
}

View File

@ -15,16 +15,17 @@
*/ */
package org.apache.hadoop.yarn.submarine.client.cli; package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.yarn.api.records.ResourceInformation; import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.api.records.ResourceTypeInfo; import org.apache.hadoop.yarn.api.records.ResourceTypeInfo;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper; import org.apache.hadoop.yarn.resourcetypes.ResourceTypesTestHelper;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.YamlConfigTestUtils;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.YamlParseException; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.util.resource.ResourceUtils; import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.junit.After; import org.junit.After;
@ -39,19 +40,18 @@ import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.apache.hadoop.yarn.submarine.client.cli.TestRunJobCliParsing.getMockClientContext; import static org.apache.hadoop.yarn.submarine.client.cli.runjob.TestRunJobCliParsingCommon.getMockClientContext;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
/** /**
* Test class that verifies the correctness of YAML configuration parsing. * Test class that verifies the correctness of TF YAML configuration parsing.
*/ */
public class TestRunJobCliParsingYaml { public class TestRunJobCliParsingTensorFlowYaml {
private static final String OVERRIDDEN_PREFIX = "overridden_"; private static final String OVERRIDDEN_PREFIX = "overridden_";
private static final String DIR_NAME = "runjobcliparsing"; private static final String DIR_NAME = "runjob-tensorflow-yaml";
private File yamlConfig; private File yamlConfig;
@Before @Before
@ -104,27 +104,39 @@ public class TestRunJobCliParsingYaml {
private void verifyPsValues(RunJobParameters jobRunParameters, private void verifyPsValues(RunJobParameters jobRunParameters,
String prefix) { String prefix) {
assertEquals(4, jobRunParameters.getNumPS()); assertTrue(RunJobParameters.class + " must be an instance of " +
assertEquals(prefix + "testLaunchCmdPs", jobRunParameters.getPSLaunchCmd()); TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(4, tensorFlowParams.getNumPS());
assertEquals(prefix + "testLaunchCmdPs", tensorFlowParams.getPSLaunchCmd());
assertEquals(prefix + "testDockerImagePs", assertEquals(prefix + "testDockerImagePs",
jobRunParameters.getPsDockerImage()); tensorFlowParams.getPsDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(20500L, 34, assertEquals(ResourceTypesTestHelper.newResource(20500L, 34,
ImmutableMap.<String, String> builder() ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "4").build()), .put(ResourceInformation.GPU_URI, "4").build()),
jobRunParameters.getPsResource()); tensorFlowParams.getPsResource());
} }
private void verifyWorkerValues(RunJobParameters jobRunParameters, private void verifyWorkerValues(RunJobParameters jobRunParameters,
String prefix) { String prefix) {
assertEquals(3, jobRunParameters.getNumWorkers()); assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertEquals(3, tensorFlowParams.getNumWorkers());
assertEquals(prefix + "testLaunchCmdWorker", assertEquals(prefix + "testLaunchCmdWorker",
jobRunParameters.getWorkerLaunchCmd()); tensorFlowParams.getWorkerLaunchCmd());
assertEquals(prefix + "testDockerImageWorker", assertEquals(prefix + "testDockerImageWorker",
jobRunParameters.getWorkerDockerImage()); tensorFlowParams.getWorkerDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(20480L, 32, assertEquals(ResourceTypesTestHelper.newResource(20480L, 32,
ImmutableMap.<String, String> builder() ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "2").build()), .put(ResourceInformation.GPU_URI, "2").build()),
jobRunParameters.getWorkerResource()); tensorFlowParams.getWorkerResource());
} }
private void verifySecurityValues(RunJobParameters jobRunParameters) { private void verifySecurityValues(RunJobParameters jobRunParameters) {
@ -134,13 +146,19 @@ public class TestRunJobCliParsingYaml {
} }
private void verifyTensorboardValues(RunJobParameters jobRunParameters) { private void verifyTensorboardValues(RunJobParameters jobRunParameters) {
assertTrue(jobRunParameters.isTensorboardEnabled()); assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertTrue(tensorFlowParams.isTensorboardEnabled());
assertEquals("tensorboardDockerImage", assertEquals("tensorboardDockerImage",
jobRunParameters.getTensorboardDockerImage()); tensorFlowParams.getTensorboardDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(21000L, 37, assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
ImmutableMap.<String, String> builder() ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "3").build()), .put(ResourceInformation.GPU_URI, "3").build()),
jobRunParameters.getTensorboardResource()); tensorFlowParams.getTensorboardResource());
} }
@Test @Test
@ -161,44 +179,6 @@ public class TestRunJobCliParsingYaml {
verifyTensorboardValues(jobRunParameters); verifyTensorboardValues(jobRunParameters);
} }
@Test
public void testYamlAndCliOptionIsDefinedIsInvalid() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"-f", yamlConfig.getAbsolutePath() };
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test
public void testYamlAndCliOptionIsDefinedIsInvalidWithListOption()
throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/valid-config.yaml");
String[] args = new String[] {"--name", "my-job",
"--quicklink", "AAA=http://master-0:8321",
"--quicklink", "BBB=http://worker-0:1234",
"-f", yamlConfig.getAbsolutePath()};
exception.expect(YarnException.class);
exception.expectMessage("defined both with YAML config and with " +
"CLI argument");
runJobCli.run(args);
}
@Test @Test
public void testRoleOverrides() throws Exception { public void testRoleOverrides() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext()); RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@ -217,104 +197,6 @@ public class TestRunJobCliParsingYaml {
verifyTensorboardValues(jobRunParameters); verifyTensorboardValues(jobRunParameters);
} }
@Test
public void testFalseValuesForBooleanFields() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/test-false-values.yaml");
runJobCli.run(
new String[] {"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertFalse(jobRunParameters.isDistributeKeytab());
assertFalse(jobRunParameters.isWaitJobFinish());
assertFalse(jobRunParameters.isTensorboardEnabled());
}
@Test
public void testWrongIndentation() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-indentation.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testWrongFilename() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
Assert.assertFalse(SubmarineLogs.isVerbose());
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", "not-existing", "--verbose"});
}
@Test
public void testEmptyFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createEmptyTempFile();
exception.expect(YamlParseException.class);
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testNotExistingFile() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
exception.expect(YamlParseException.class);
exception.expectMessage("file does not exist");
runJobCli.run(
new String[]{"-f", "blabla", "--verbose"});
}
@Test
public void testWrongPropertyName() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/wrong-property-name.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("Failed to parse YAML config, details:");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingConfigsSection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/missing-configs.yaml");
exception.expect(YamlParseException.class);
exception.expectMessage("config section should be defined, " +
"but it cannot be found");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test
public void testMissingSectionsShouldParsed() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext());
yamlConfig = YamlConfigTestUtils.createTempFileWithContents(
DIR_NAME + "/some-sections-missing.yaml");
runJobCli.run(
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
}
@Test @Test
public void testMissingPrincipalUnderSecuritySection() throws Exception { public void testMissingPrincipalUnderSecuritySection() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext()); RunJobCli runJobCli = new RunJobCli(getMockClientContext());
@ -346,18 +228,22 @@ public class TestRunJobCliParsingYaml {
new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"}); new String[]{"-f", yamlConfig.getAbsolutePath(), "--verbose"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters(); RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
verifyBasicConfigValues(jobRunParameters); verifyBasicConfigValues(jobRunParameters);
verifyPsValues(jobRunParameters, ""); verifyPsValues(jobRunParameters, "");
verifyWorkerValues(jobRunParameters, ""); verifyWorkerValues(jobRunParameters, "");
verifySecurityValues(jobRunParameters); verifySecurityValues(jobRunParameters);
assertTrue(jobRunParameters.isTensorboardEnabled()); TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
assertTrue(tensorFlowParams.isTensorboardEnabled());
assertNull("tensorboardDockerImage should be null!", assertNull("tensorboardDockerImage should be null!",
jobRunParameters.getTensorboardDockerImage()); tensorFlowParams.getTensorboardDockerImage());
assertEquals(ResourceTypesTestHelper.newResource(21000L, 37, assertEquals(ResourceTypesTestHelper.newResource(21000L, 37,
ImmutableMap.<String, String> builder() ImmutableMap.<String, String> builder()
.put(ResourceInformation.GPU_URI, "3").build()), .put(ResourceInformation.GPU_URI, "3").build()),
jobRunParameters.getTensorboardResource()); tensorFlowParams.getTensorboardResource());
} }
@Test @Test

View File

@ -15,7 +15,7 @@
*/ */
package org.apache.hadoop.yarn.submarine.client.cli; package org.apache.hadoop.yarn.submarine.client.cli.runjob.tensorflow;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Configs;
import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role; import org.apache.hadoop.yarn.submarine.client.cli.param.yaml.Role;
@ -42,14 +42,9 @@ import static org.junit.Assert.assertTrue;
* Please note that this class just tests YAML parsing, * Please note that this class just tests YAML parsing,
* but only in an isolated fashion. * but only in an isolated fashion.
*/ */
public class TestRunJobCliParsingYamlStandalone { public class TestRunJobCliParsingTensorFlowYamlStandalone {
private static final String OVERRIDDEN_PREFIX = "overridden_"; private static final String OVERRIDDEN_PREFIX = "overridden_";
private static final String DIR_NAME = "runjobcliparsing"; private static final String DIR_NAME = "runjob-tensorflow-yaml";
@Before
public void before() {
SubmarineLogs.verboseOff();
}
private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) { private void verifyBasicConfigValues(YamlConfigFile yamlConfigFile) {
assertNotNull("Spec file should not be null!", yamlConfigFile); assertNotNull("Spec file should not be null!", yamlConfigFile);
@ -169,6 +164,11 @@ public class TestRunJobCliParsingYamlStandalone {
assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources()); assertEquals("memory=21000M,vcores=37,gpu=3", tensorBoard.getResources());
} }
@Before
public void before() {
SubmarineLogs.verboseOff();
}
@Test @Test
public void testLaunchCommandYaml() { public void testLaunchCommandYaml() {
YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME + YamlConfigFile yamlConfigFile = readYamlConfigFile(DIR_NAME +
@ -201,5 +201,4 @@ public class TestRunJobCliParsingYamlStandalone {
assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker"); assertRoleConfigOverrides(roles.getWorker(), OVERRIDDEN_PREFIX, "Worker");
assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps"); assertRoleConfigOverrides(roles.getPs(), OVERRIDDEN_PREFIX, "Ps");
} }
} }

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework:
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
resources: memory=20500M,vcores=34,gpu=4
replicas: 4
launch_cmd: testLaunchCmdPs
docker_image: testDockerImagePs
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
resources: memory=21000M,vcores=37,gpu=3
docker_image: tensorboardDockerImage

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: bla
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
resources: memory=20500M,vcores=34,gpu=4
replicas: 4
launch_cmd: testLaunchCmdPs
docker_image: testDockerImagePs
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
resources: memory=21000M,vcores=37,gpu=3
docker_image: tensorboardDockerImage

View File

@ -17,6 +17,7 @@
spec: spec:
name: testJobName name: testJobName
job_type: testJobType job_type: testJobType
framework: tensorflow
configs: configs:
input_path: testInputPath input_path: testInputPath

View File

@ -17,6 +17,7 @@
spec: spec:
name: testJobName name: testJobName
job_type: testJobType job_type: testJobType
framework: tensorflow
configs: configs:
input_path: testInputPath input_path: testInputPath

View File

@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
docker_image: testPSDockerImage
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
docker_image: tensorboardDockerImage

View File

@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
distribute_keytab: true

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: overridden_testLaunchCmdWorker
docker_image: overridden_testDockerImageWorker
envs:
env1: 'overridden_env1Worker'
env2: 'overridden_env2Worker'
localizations:
- hdfs://remote-file1:/overridden_local-filename1Worker:rw
- nfs://remote-file2:/overridden_local-filename2Worker:rw
mounts:
- /etc/passwd:/overridden_Worker
- /etc/hosts:/overridden_Worker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: pytorch
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true

View File

@ -17,6 +17,7 @@
spec: spec:
name: testJobName name: testJobName
job_type: testJobType job_type: testJobType
framework: tensorflow
configs: configs:
input_path: testInputPath input_path: testInputPath

View File

@ -17,6 +17,7 @@
spec: spec:
name: testJobName name: testJobName
job_type: testJobType job_type: testJobType
framework: tensorflow
configs: configs:
input_path: testInputPath input_path: testInputPath

View File

@ -17,6 +17,7 @@
spec: spec:
name: testJobName name: testJobName
job_type: testJobType job_type: testJobType
framework: tensorflow
configs: configs:
input_path: testInputPath input_path: testInputPath

View File

@ -17,6 +17,7 @@
spec: spec:
name: testJobName name: testJobName
job_type: testJobType job_type: testJobType
framework: tensorflow
configs: configs:
input_path: testInputPath input_path: testInputPath

View File

@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
spec:
name: testJobName
job_type: testJobType
framework: tensorflow
configs:
input_path: testInputPath
checkpoint_path: testCheckpointPath
saved_model_path: testSavedModelPath
docker_image: testDockerImage
wait_job_finish: true
envs:
env1: 'env1Value'
env2: 'env2Value'
localizations:
- hdfs://remote-file1:/local-filename1:rw
- nfs://remote-file2:/local-filename2:rw
mounts:
- /etc/passwd:/etc/passwd:rw
- /etc/hosts:/etc/hosts:rw
quicklinks:
- Notebook_UI=https://master-0:7070
- Notebook_UI2=https://master-0:7071
scheduling:
queue: queue1
roles:
worker:
resources: memory=20480M,vcores=32,gpu=2
replicas: 3
launch_cmd: testLaunchCmdWorker
docker_image: testDockerImageWorker
ps:
resources: memory=20500M,vcores=34,gpu=4
replicas: 4
launch_cmd: testLaunchCmdPs
docker_image: testDockerImagePs
security:
keytab: keytabPath
principal: testPrincipal
distribute_keytab: true
tensorBoard:
resources: memory=21000M,vcores=37,gpu=3
docker_image: tensorboardDockerImage

View File

@ -22,7 +22,9 @@ import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter; import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import java.io.File; import java.io.File;
@ -45,14 +47,24 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
} }
@Override @Override
public ApplicationId submitJob(RunJobParameters parameters) public ApplicationId submitJob(ParametersHolder parameters)
throws IOException, YarnException { throws IOException {
if (parameters.getFramework() == Framework.PYTORCH) {
// we need to throw an exception, as ParametersHolder's parameters field
// could not be casted to TensorFlowRunJobParameters.
throw new UnsupportedOperationException(
"Support \"-framework\" option for PyTorch in Tony is coming. " +
"Please check the documentation about how to submit a " +
"PyTorch job with TonY runtime.");
}
LOG.info("Starting Tony runtime.."); LOG.info("Starting Tony runtime..");
File tonyFinalConfPath = File.createTempFile("temp", File tonyFinalConfPath = File.createTempFile("temp",
Constants.TONY_FINAL_XML); Constants.TONY_FINAL_XML);
// Write user's overridden conf to an xml to be localized. // Write user's overridden conf to an xml to be localized.
Configuration tonyConf = TonyUtils.tonyConfFromClientContext(parameters); Configuration tonyConf = TonyUtils.tonyConfFromClientContext(
(TensorFlowRunJobParameters) parameters.getParameters());
try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) { try (OutputStream os = new FileOutputStream(tonyFinalConfPath)) {
tonyConf.writeXml(os); tonyConf.writeXml(os);
} catch (IOException e) { } catch (IOException e) {
@ -68,7 +80,7 @@ public class TonyJobSubmitter implements JobSubmitter, CallbackHandler {
LOG.error("Failed to init TonyClient: ", e); LOG.error("Failed to init TonyClient: ", e);
} }
Thread clientThread = new Thread(tonyClient::start); Thread clientThread = new Thread(tonyClient::start);
Runtime.getRuntime().addShutdownHook(new Thread(() -> { java.lang.Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try { try {
tonyClient.forceKillApplication(); tonyClient.forceKillApplication();
} catch (YarnException | IOException e) { } catch (YarnException | IOException e) {

View File

@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ResourceInformation; import org.apache.hadoop.yarn.api.records.ResourceInformation;
import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException; import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -35,7 +35,7 @@ public final class TonyUtils {
private static final Log LOG = LogFactory.getLog(TonyUtils.class); private static final Log LOG = LogFactory.getLog(TonyUtils.class);
public static Configuration tonyConfFromClientContext( public static Configuration tonyConfFromClientContext(
RunJobParameters parameters) { TensorFlowRunJobParameters parameters) {
Configuration tonyConf = new Configuration(); Configuration tonyConf = new Configuration();
tonyConf.setInt( tonyConf.setInt(
TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME), TonyConfigurationKeys.getInstancesKey(Constants.WORKER_JOB_NAME),

View File

@ -147,6 +147,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \ /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \ java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--num_workers 2 \ --num_workers 2 \
--worker_resources memory=3G,vcores=2 \ --worker_resources memory=3G,vcores=2 \
--num_ps 2 \ --num_ps 2 \
@ -183,6 +184,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \ /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \ java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \ --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
--input_path hdfs://pi-aw:9000/dataset/cifar-10-data \ --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
--worker_resources memory=3G,vcores=2 \ --worker_resources memory=3G,vcores=2 \
@ -245,6 +247,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \ /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \ java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--num_workers 2 \ --num_workers 2 \
--worker_resources memory=3G,vcores=2 \ --worker_resources memory=3G,vcores=2 \
--num_ps 2 \ --num_ps 2 \
@ -281,6 +284,7 @@ CLASSPATH=$(hadoop classpath --glob): \
/home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \ /home/pi/hadoop/TonY/tony-cli/build/libs/tony-cli-0.3.2-all.jar \
java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \ java org.apache.hadoop.yarn.submarine.client.cli.Cli job run --name tf-job-001 \
--framework tensorflow \
--docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \ --docker_image hadoopsubmarine/tf-1.8.0-cpu:0.0.3 \
--input_path hdfs://pi-aw:9000/dataset/cifar-10-data \ --input_path hdfs://pi-aw:9000/dataset/cifar-10-data \
--worker_resources memory=3G,vcores=2 \ --worker_resources memory=3G,vcores=2 \

View File

@ -16,8 +16,10 @@ import com.linkedin.tony.TonyConfigurationKeys;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.submarine.client.cli.RunJobCli; import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.RunJobCli;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.MockClientContext; import org.apache.hadoop.yarn.submarine.common.MockClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory; import org.apache.hadoop.yarn.submarine.runtimes.RuntimeFactory;
@ -31,6 +33,7 @@ import org.junit.Test;
import java.io.IOException; import java.io.IOException;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -59,7 +62,8 @@ public class TestTonyUtils {
throws IOException, YarnException { throws IOException, YarnException {
MockClientContext mockClientContext = new MockClientContext(); MockClientContext mockClientContext = new MockClientContext();
JobSubmitter mockJobSubmitter = mock(JobSubmitter.class); JobSubmitter mockJobSubmitter = mock(JobSubmitter.class);
when(mockJobSubmitter.submitJob(any(RunJobParameters.class))).thenReturn( when(mockJobSubmitter.submitJob(
any(ParametersHolder.class))).thenReturn(
ApplicationId.newInstance(1234L, 1)); ApplicationId.newInstance(1234L, 1));
JobMonitor mockJobMonitor = mock(JobMonitor.class); JobMonitor mockJobMonitor = mock(JobMonitor.class);
SubmarineStorage storage = mock(SubmarineStorage.class); SubmarineStorage storage = mock(SubmarineStorage.class);
@ -82,20 +86,28 @@ public class TestTonyUtils {
public void testTonyConfFromClientContext() throws Exception { public void testTonyConfFromClientContext() throws Exception {
RunJobCli runJobCli = new RunJobCli(getMockClientContext()); RunJobCli runJobCli = new RunJobCli(getMockClientContext());
runJobCli.run( runJobCli.run(
new String[] {"--name", "my-job", "--docker_image", "tf-docker:1.1.0", new String[] {"--framework", "tensorflow", "--name", "my-job",
"--docker_image", "tf-docker:1.1.0",
"--input_path", "hdfs://input", "--input_path", "hdfs://input",
"--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd", "--num_workers", "3", "--num_ps", "2", "--worker_launch_cmd",
"python run-job.py", "--worker_resources", "memory=2048M,vcores=2", "python run-job.py", "--worker_resources", "memory=2048M,vcores=2",
"--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd", "--ps_resources", "memory=4G,vcores=4", "--ps_launch_cmd",
"python run-ps.py"}); "python run-ps.py"});
RunJobParameters jobRunParameters = runJobCli.getRunJobParameters(); RunJobParameters jobRunParameters = runJobCli.getRunJobParameters();
assertTrue(RunJobParameters.class + " must be an instance of " +
TensorFlowRunJobParameters.class,
jobRunParameters instanceof TensorFlowRunJobParameters);
TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) jobRunParameters;
Configuration tonyConf = TonyUtils Configuration tonyConf = TonyUtils
.tonyConfFromClientContext(jobRunParameters); .tonyConfFromClientContext(tensorFlowParams);
Assert.assertEquals(jobRunParameters.getDockerImageName(), Assert.assertEquals(jobRunParameters.getDockerImageName(),
tonyConf.get(TonyConfigurationKeys.getContainerDockerKey())); tonyConf.get(TonyConfigurationKeys.getContainerDockerKey()));
Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys Assert.assertEquals("3", tonyConf.get(TonyConfigurationKeys
.getInstancesKey("worker"))); .getInstancesKey("worker")));
Assert.assertEquals(jobRunParameters.getWorkerLaunchCmd(), Assert.assertEquals(tensorFlowParams.getWorkerLaunchCmd(),
tonyConf.get(TonyConfigurationKeys tonyConf.get(TonyConfigurationKeys
.getExecuteCommandKey("worker"))); .getExecuteCommandKey("worker")));
Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys Assert.assertEquals("2048", tonyConf.get(TonyConfigurationKeys
@ -107,7 +119,7 @@ public class TestTonyUtils {
Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys Assert.assertEquals("4", tonyConf.get(TonyConfigurationKeys
.getResourceKey(Constants.PS_JOB_NAME, .getResourceKey(Constants.PS_JOB_NAME,
Constants.VCORES))); Constants.VCORES)));
Assert.assertEquals(jobRunParameters.getPSLaunchCmd(), Assert.assertEquals(tensorFlowParams.getPSLaunchCmd(),
tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps"))); tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey("ps")));
} }
} }

View File

@ -19,8 +19,10 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
@ -28,7 +30,11 @@ import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchComma
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
import static org.apache.hadoop.yarn.service.conf.YarnServiceConstants.CONTAINER_STATE_REPORT_AS_SERVICE_STATE;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.addCommonEnvironments;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName; import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getScriptFileName;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.SubmarineResourceUtils.convertYarnResourceToServiceResource;
/** /**
* Abstract base class for Component classes. * Abstract base class for Component classes.
@ -40,7 +46,7 @@ import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.T
public abstract class AbstractComponent { public abstract class AbstractComponent {
private final FileSystemOperations fsOperations; private final FileSystemOperations fsOperations;
protected final RunJobParameters parameters; protected final RunJobParameters parameters;
protected final TaskType taskType; protected final Role role;
private final RemoteDirectoryManager remoteDirectoryManager; private final RemoteDirectoryManager remoteDirectoryManager;
protected final Configuration yarnConfig; protected final Configuration yarnConfig;
private final LaunchCommandFactory launchCommandFactory; private final LaunchCommandFactory launchCommandFactory;
@ -52,19 +58,55 @@ public abstract class AbstractComponent {
public AbstractComponent(FileSystemOperations fsOperations, public AbstractComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager, RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters, TaskType taskType, RunJobParameters parameters, Role role,
Configuration yarnConfig, Configuration yarnConfig,
LaunchCommandFactory launchCommandFactory) { LaunchCommandFactory launchCommandFactory) {
this.fsOperations = fsOperations; this.fsOperations = fsOperations;
this.remoteDirectoryManager = remoteDirectoryManager; this.remoteDirectoryManager = remoteDirectoryManager;
this.parameters = parameters; this.parameters = parameters;
this.taskType = taskType; this.role = role;
this.launchCommandFactory = launchCommandFactory; this.launchCommandFactory = launchCommandFactory;
this.yarnConfig = yarnConfig; this.yarnConfig = yarnConfig;
} }
protected abstract Component createComponent() throws IOException; protected abstract Component createComponent() throws IOException;
protected Component createComponentInternal() throws IOException {
Objects.requireNonNull(this.parameters.getWorkerResource(),
"Worker resource must not be null!");
if (parameters.getNumWorkers() < 1) {
throw new IllegalArgumentException(
"Number of workers should be at least 1!");
}
Component component = new Component();
component.setName(role.getComponentName());
if (role.equals(TensorFlowRole.PRIMARY_WORKER) ||
role.equals(PyTorchRole.PRIMARY_WORKER)) {
component.setNumberOfContainers(1L);
component.getConfiguration().setProperty(
CONTAINER_STATE_REPORT_AS_SERVICE_STATE, "true");
} else {
component.setNumberOfContainers(
(long) parameters.getNumWorkers() - 1);
}
if (parameters.getWorkerDockerImage() != null) {
component.setArtifact(
getDockerArtifact(parameters.getWorkerDockerImage()));
}
component.setResource(convertYarnResourceToServiceResource(
parameters.getWorkerResource()));
component.setRestartPolicy(Component.RestartPolicyEnum.NEVER);
addCommonEnvironments(component, role);
generateLaunchCommand(component);
return component;
}
/** /**
* Generates a command launch script on local disk, * Generates a command launch script on local disk,
* returns path to the script. * returns path to the script.
@ -72,7 +114,7 @@ public abstract class AbstractComponent {
protected void generateLaunchCommand(Component component) protected void generateLaunchCommand(Component component)
throws IOException { throws IOException {
AbstractLaunchCommand launchCommand = AbstractLaunchCommand launchCommand =
launchCommandFactory.createLaunchCommand(taskType, component); launchCommandFactory.createLaunchCommand(role, component);
this.localScriptFile = launchCommand.generateLaunchScript(); this.localScriptFile = launchCommand.generateLaunchScript();
String remoteLaunchCommand = uploadLaunchCommand(component); String remoteLaunchCommand = uploadLaunchCommand(component);
@ -86,7 +128,7 @@ public abstract class AbstractComponent {
Path stagingDir = Path stagingDir =
remoteDirectoryManager.getJobStagingArea(parameters.getName(), true); remoteDirectoryManager.getJobStagingArea(parameters.getName(), true);
String destScriptFileName = getScriptFileName(taskType); String destScriptFileName = getScriptFileName(role);
fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir, fsOperations.uploadToRemoteFileAndLocalizeToContainerWorkDir(stagingDir,
localScriptFile, destScriptFileName, component); localScriptFile, destScriptFileName, component);

View File

@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
/**
* Abstract base class that supports creating service specs for Native Service.
*/
public abstract class AbstractServiceSpec implements ServiceSpec {
private static final Logger LOG =
LoggerFactory.getLogger(AbstractServiceSpec.class);
protected final RunJobParameters parameters;
protected final FileSystemOperations fsOperations;
private final Localizer localizer;
protected final RemoteDirectoryManager remoteDirectoryManager;
protected final Configuration yarnConfig;
protected final LaunchCommandFactory launchCommandFactory;
private final WorkerComponentFactory workerFactory;
public AbstractServiceSpec(RunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations,
LaunchCommandFactory launchCommandFactory,
Localizer localizer) {
this.parameters = parameters;
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager();
this.yarnConfig = clientContext.getYarnConfig();
this.fsOperations = fsOperations;
this.localizer = localizer;
this.launchCommandFactory = launchCommandFactory;
this.workerFactory = new WorkerComponentFactory(fsOperations,
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig);
}
protected ServiceWrapper createServiceSpecWrapper() throws IOException {
Service serviceSpec = new Service();
serviceSpec.setName(parameters.getName());
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
.create(fsOperations, remoteDirectoryManager, parameters);
if (kerberosPrincipal != null) {
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
}
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
localizer.handleLocalizations(serviceSpec);
return new ServiceWrapper(serviceSpec);
}
// Handle worker and primary_worker.
protected void addWorkerComponents(ServiceWrapper serviceWrapper,
Framework framework)
throws IOException {
final Role primaryWorkerRole;
final Role workerRole;
if (framework == Framework.TENSORFLOW) {
primaryWorkerRole = TensorFlowRole.PRIMARY_WORKER;
workerRole = TensorFlowRole.WORKER;
} else {
primaryWorkerRole = PyTorchRole.PRIMARY_WORKER;
workerRole = PyTorchRole.WORKER;
}
addWorkerComponent(serviceWrapper, primaryWorkerRole, framework);
if (parameters.getNumWorkers() > 1) {
addWorkerComponent(serviceWrapper, workerRole, framework);
}
}
private void addWorkerComponent(ServiceWrapper serviceWrapper,
Role role, Framework framework) throws IOException {
AbstractComponent component = workerFactory.create(framework, role);
serviceWrapper.addComponent(component);
}
protected void handleQuicklinks(Service serviceSpec)
throws IOException {
List<Quicklink> quicklinks = parameters.getQuicklinks();
if (quicklinks != null && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol()
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
protected static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (quicklinks == null) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
}

View File

@ -37,6 +37,7 @@ import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.io.IOException; import java.io.IOException;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Set; import java.util.Set;
/** /**
@ -195,6 +196,15 @@ public class FileSystemOperations {
fs.setPermission(destPath, new FsPermission(permission)); fs.setPermission(destPath, new FsPermission(permission));
} }
public static boolean needHdfs(List<String> stringsToCheck) {
for (String content : stringsToCheck) {
if (content != null && content.contains("hdfs://")) {
return true;
}
}
return false;
}
public static boolean needHdfs(String content) { public static boolean needHdfs(String content) {
return content != null && content.contains("hdfs://"); return content != null && content.contains("hdfs://");
} }

View File

@ -16,9 +16,10 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice; package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.curator.shaded.com.google.common.collect.ImmutableList;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
@ -28,6 +29,8 @@ import org.slf4j.LoggerFactory;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.List;
import java.util.Objects;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs; import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations.needHdfs;
import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath; import static org.apache.hadoop.yarn.submarine.utils.ClassPathUtilities.findFileOnClassPath;
@ -128,10 +131,22 @@ public class HadoopEnvironmentSetup {
} }
private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) { private boolean doesNeedHdfs(RunJobParameters parameters, boolean hadoopEnv) {
return needHdfs(parameters.getInputPath()) || List<String> launchCommands = parameters.getLaunchCommands();
needHdfs(parameters.getPSLaunchCmd()) || if (launchCommands != null) {
needHdfs(parameters.getWorkerLaunchCmd()) || launchCommands.removeIf(Objects::isNull);
hadoopEnv; }
ImmutableList.Builder<String> listBuilder = ImmutableList.builder();
if (launchCommands != null && !launchCommands.isEmpty()) {
listBuilder.addAll(launchCommands);
}
if (parameters.getInputPath() != null) {
listBuilder.add(parameters.getInputPath());
}
List<String> stringsToCheck = listBuilder.build();
return needHdfs(stringsToCheck) || hadoopEnv;
} }
private void appendHdfsHome(PrintWriter fw, String hdfsHome) { private void appendHdfsHome(PrintWriter fw, String hdfsHome) {

View File

@ -38,7 +38,7 @@ public final class ServiceSpecFileGenerator {
"instantiated!"); "instantiated!");
} }
static String generateJson(Service service) throws IOException { public static String generateJson(Service service) throws IOException {
File serviceSpecFile = File.createTempFile(service.getName(), ".json"); File serviceSpecFile = File.createTempFile(service.getName(), ".json");
String buffer = jsonSerDeser.toJson(service); String buffer = jsonSerDeser.toJson(service);
Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile), Writer w = new OutputStreamWriter(new FileOutputStream(serviceSpecFile),

View File

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component.PyTorchWorkerComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
/**
* Factory class that helps creating Native Service components.
*/
public class WorkerComponentFactory {
private final FileSystemOperations fsOperations;
private final RemoteDirectoryManager remoteDirectoryManager;
private final RunJobParameters parameters;
private final LaunchCommandFactory launchCommandFactory;
private final Configuration yarnConfig;
WorkerComponentFactory(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters,
LaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
this.fsOperations = fsOperations;
this.remoteDirectoryManager = remoteDirectoryManager;
this.parameters = parameters;
this.launchCommandFactory = launchCommandFactory;
this.yarnConfig = yarnConfig;
}
/**
* Creates either a TensorFlow or a PyTorch Native Service component.
*/
public AbstractComponent create(Framework framework, Role role) {
if (framework == Framework.TENSORFLOW) {
return new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
(TensorFlowRunJobParameters) parameters, role,
(TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
} else if (framework == Framework.PYTORCH) {
return new PyTorchWorkerComponent(fsOperations, remoteDirectoryManager,
(PyTorchRunJobParameters) parameters, role,
(PyTorchLaunchCommandFactory) launchCommandFactory, yarnConfig);
} else {
throw new UnsupportedOperationException("Only supported frameworks are: "
+ Framework.getValues());
}
}
}

View File

@ -20,10 +20,16 @@ import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.records.Service; import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.service.utils.ServiceApiUtil; import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter; import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
import org.apache.hadoop.yarn.submarine.utils.Localizer; import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -32,6 +38,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS; import static org.apache.hadoop.yarn.service.exceptions.LauncherExitCodes.EXIT_SUCCESS;
import static org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder.SUPPORTED_FRAMEWORKS_MESSAGE;
/** /**
* Submit a job to cluster. * Submit a job to cluster.
@ -51,14 +58,45 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
* {@inheritDoc} * {@inheritDoc}
*/ */
@Override @Override
public ApplicationId submitJob(RunJobParameters parameters) public ApplicationId submitJob(ParametersHolder paramsHolder)
throws IOException, YarnException { throws IOException, YarnException {
Framework framework = paramsHolder.getFramework();
RunJobParameters parameters =
(RunJobParameters) paramsHolder.getParameters();
if (framework == Framework.TENSORFLOW) {
return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
} else if (framework == Framework.PYTORCH) {
return submitPyTorchJob((PyTorchRunJobParameters) parameters);
} else {
throw new UnsupportedOperationException(SUPPORTED_FRAMEWORKS_MESSAGE);
}
}
private ApplicationId submitTensorFlowJob(
TensorFlowRunJobParameters parameters) throws IOException, YarnException {
FileSystemOperations fsOperations = new FileSystemOperations(clientContext); FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
HadoopEnvironmentSetup hadoopEnvSetup = HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(clientContext, fsOperations); new HadoopEnvironmentSetup(clientContext, fsOperations);
Service serviceSpec = createTensorFlowServiceSpec(parameters, Service serviceSpec = createTensorFlowServiceSpec(parameters,
fsOperations, hadoopEnvSetup); fsOperations, hadoopEnvSetup);
return submitJobInternal(serviceSpec);
}
private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters)
throws IOException, YarnException {
FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
HadoopEnvironmentSetup hadoopEnvSetup =
new HadoopEnvironmentSetup(clientContext, fsOperations);
Service serviceSpec = createPyTorchServiceSpec(parameters,
fsOperations, hadoopEnvSetup);
return submitJobInternal(serviceSpec);
}
private ApplicationId submitJobInternal(Service serviceSpec)
throws IOException, YarnException {
String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec); String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
AppAdminClient appAdminClient = AppAdminClient appAdminClient =
@ -70,7 +108,7 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
"Fail to launch application with exit code:" + code); "Fail to launch application with exit code:" + code);
} }
String appStatus=appAdminClient.getStatusString(serviceSpec.getName()); String appStatus = appAdminClient.getStatusString(serviceSpec.getName());
Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus); Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
// Retry multiple times if applicationId is null // Retry multiple times if applicationId is null
@ -97,11 +135,12 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
return appid; return appid;
} }
private Service createTensorFlowServiceSpec(RunJobParameters parameters, private Service createTensorFlowServiceSpec(
TensorFlowRunJobParameters parameters,
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup) FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
throws IOException { throws IOException {
LaunchCommandFactory launchCommandFactory = TensorFlowLaunchCommandFactory launchCommandFactory =
new LaunchCommandFactory(hadoopEnvSetup, parameters, new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters,
clientContext.getYarnConfig()); clientContext.getYarnConfig());
Localizer localizer = new Localizer(fsOperations, Localizer localizer = new Localizer(fsOperations,
clientContext.getRemoteDirectoryManager(), parameters); clientContext.getRemoteDirectoryManager(), parameters);
@ -113,6 +152,22 @@ public class YarnServiceJobSubmitter implements JobSubmitter {
return serviceWrapper.getService(); return serviceWrapper.getService();
} }
private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup)
throws IOException {
PyTorchLaunchCommandFactory launchCommandFactory =
new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters,
clientContext.getYarnConfig());
Localizer localizer = new Localizer(fsOperations,
clientContext.getRemoteDirectoryManager(), parameters);
PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
parameters, this.clientContext, fsOperations, launchCommandFactory,
localizer);
serviceWrapper = pyTorchServiceSpec.create();
return serviceWrapper.getService();
}
@VisibleForTesting @VisibleForTesting
public ServiceWrapper getServiceWrapper() { public ServiceWrapper getServiceWrapper() {
return serviceWrapper; return serviceWrapper;

View File

@ -17,11 +17,9 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import java.io.IOException; import java.io.IOException;
import java.util.Objects;
/** /**
* Abstract base class for Launch command implementations for Services. * Abstract base class for Launch command implementations for Services.
@ -32,10 +30,9 @@ public abstract class AbstractLaunchCommand {
private final LaunchScriptBuilder builder; private final LaunchScriptBuilder builder;
public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, public AbstractLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters) Component component, RunJobParameters parameters,
throws IOException { String launchCommandPrefix) throws IOException {
Objects.requireNonNull(taskType, "TaskType must not be null!"); this.builder = new LaunchScriptBuilder(launchCommandPrefix, hadoopEnvSetup,
this.builder = new LaunchScriptBuilder(taskType.name(), hadoopEnvSetup,
parameters, component); parameters, component);
} }

View File

@ -16,52 +16,15 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TaskType;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import java.io.IOException; import java.io.IOException;
import java.util.Objects;
/** /**
* Simple factory to create instances of {@link AbstractLaunchCommand} * Interface for creating launch commands.
* based on the {@link TaskType}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
*/ */
public class LaunchCommandFactory { public interface LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup; AbstractLaunchCommand createLaunchCommand(Role role, Component component)
private final RunJobParameters parameters; throws IOException;
private final Configuration yarnConfig;
public LaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
RunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
public AbstractLaunchCommand createLaunchCommand(TaskType taskType,
Component component) throws IOException {
Objects.requireNonNull(taskType, "TaskType must not be null!");
if (taskType == TaskType.WORKER || taskType == TaskType.PRIMARY_WORKER) {
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, taskType,
component, parameters, yarnConfig);
} else if (taskType == TaskType.PS) {
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, taskType, component,
parameters, yarnConfig);
} else if (taskType == TaskType.TENSORBOARD) {
return new TensorBoardLaunchCommand(hadoopEnvSetup, taskType, component,
parameters);
}
throw new IllegalStateException("Unknown task type: " + taskType);
}
} }

View File

@ -17,7 +17,7 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command; package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -47,10 +47,11 @@ public class LaunchScriptBuilder {
private final StringBuilder scriptBuffer; private final StringBuilder scriptBuffer;
private String launchCommand; private String launchCommand;
LaunchScriptBuilder(String namePrefix, LaunchScriptBuilder(String launchScriptPrefix,
HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters, HadoopEnvironmentSetup hadoopEnvSetup, RunJobParameters parameters,
Component component) throws IOException { Component component) throws IOException {
this.file = File.createTempFile(namePrefix + "-launch-script", ".sh"); this.file = File.createTempFile(launchScriptPrefix +
"-launch-script", ".sh");
this.hadoopEnvSetup = hadoopEnvSetup; this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters; this.parameters = parameters;
this.component = component; this.component = component;

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import java.io.IOException;
import java.util.Objects;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.PyTorchRole;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command.PyTorchWorkerLaunchCommand;
/**
* Simple factory to create instances of {@link AbstractLaunchCommand}
* based on the {@link Role}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
*/
public class PyTorchLaunchCommandFactory implements LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final PyTorchRunJobParameters parameters;
private final Configuration yarnConfig;
public PyTorchLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
PyTorchRunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
public AbstractLaunchCommand createLaunchCommand(Role role,
Component component) throws IOException {
Objects.requireNonNull(role, "Role must not be null!");
if (role == PyTorchRole.WORKER ||
role == PyTorchRole.PRIMARY_WORKER) {
return new PyTorchWorkerLaunchCommand(hadoopEnvSetup, role,
component, parameters, yarnConfig);
} else {
throw new IllegalStateException("Unknown task type: " + role);
}
}
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorBoardLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowPsLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command.TensorFlowWorkerLaunchCommand;
import java.io.IOException;
import java.util.Objects;
/**
* Simple factory to create instances of {@link AbstractLaunchCommand}
* based on the {@link Role}.
* All dependencies are passed to this factory that could be required
* by any implementor of {@link AbstractLaunchCommand}.
*/
public class TensorFlowLaunchCommandFactory implements LaunchCommandFactory {
private final HadoopEnvironmentSetup hadoopEnvSetup;
private final TensorFlowRunJobParameters parameters;
private final Configuration yarnConfig;
public TensorFlowLaunchCommandFactory(HadoopEnvironmentSetup hadoopEnvSetup,
TensorFlowRunJobParameters parameters, Configuration yarnConfig) {
this.hadoopEnvSetup = hadoopEnvSetup;
this.parameters = parameters;
this.yarnConfig = yarnConfig;
}
@Override
public AbstractLaunchCommand createLaunchCommand(Role role,
Component component) throws IOException {
Objects.requireNonNull(role, "Role must not be null!");
if (role == TensorFlowRole.WORKER ||
role == TensorFlowRole.PRIMARY_WORKER) {
return new TensorFlowWorkerLaunchCommand(hadoopEnvSetup, role,
component, parameters, yarnConfig);
} else if (role == TensorFlowRole.PS) {
return new TensorFlowPsLaunchCommand(hadoopEnvSetup, role, component,
parameters, yarnConfig);
} else if (role == TensorFlowRole.TENSORBOARD) {
return new TensorBoardLaunchCommand(hadoopEnvSetup, role, component,
parameters);
}
throw new IllegalStateException("Unknown task type: " + role);
}
}

View File

@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;
import java.io.IOException;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This class contains all the logic to create an instance
* of a {@link Service} object for PyTorch.
* Please note that currently, only single-node (non-distributed)
* support is implemented for PyTorch.
*/
public class PyTorchServiceSpec extends AbstractServiceSpec {
private static final Logger LOG =
LoggerFactory.getLogger(PyTorchServiceSpec.class);
//this field is needed in the future!
private final PyTorchRunJobParameters pyTorchParameters;
public PyTorchServiceSpec(PyTorchRunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations,
PyTorchLaunchCommandFactory launchCommandFactory, Localizer localizer) {
super(parameters, clientContext, fsOperations, launchCommandFactory,
localizer);
this.pyTorchParameters = parameters;
}
@Override
public ServiceWrapper create() throws IOException {
LOG.info("Creating PyTorch service spec");
ServiceWrapper serviceWrapper = createServiceSpecWrapper();
if (parameters.getNumWorkers() > 0) {
addWorkerComponents(serviceWrapper, Framework.PYTORCH);
}
// After all components added, handle quicklinks
handleQuicklinks(serviceWrapper.getService());
return serviceWrapper;
}
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;
import java.io.IOException;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Launch command implementation for PyTorch components.
*/
public class PyTorchWorkerLaunchCommand extends AbstractLaunchCommand {
private static final Logger LOG =
LoggerFactory.getLogger(PyTorchWorkerLaunchCommand.class);
private final Configuration yarnConfig;
private final boolean distributed;
private final int numberOfWorkers;
private final String name;
private final Role role;
private final String launchCommand;
public PyTorchWorkerLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
Role role, Component component,
PyTorchRunJobParameters parameters,
Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, component, parameters, role.getName());
this.role = role;
this.name = parameters.getName();
this.distributed = parameters.isDistributed();
this.numberOfWorkers = parameters.getNumWorkers();
this.yarnConfig = yarnConfig;
logReceivedParameters();
this.launchCommand = parameters.getWorkerLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) {
throw new IllegalArgumentException("LaunchCommand must not be null " +
"or empty!");
}
}
private void logReceivedParameters() {
if (this.numberOfWorkers <= 0) {
LOG.warn("Received number of workers: {}", this.numberOfWorkers);
}
}
@Override
public String generateLaunchScript() throws IOException {
LaunchScriptBuilder builder = getBuilder();
return builder
.withLaunchCommand(createLaunchCommand())
.build();
}
@Override
public String createLaunchCommand() {
if (SubmarineLogs.isVerbose()) {
LOG.info("PyTorch Worker command =[" + launchCommand + "]");
}
return launchCommand + '\n';
}
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate PyTorch launch commands.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.command;

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import java.io.IOException;
/**
* Component implementation for Worker process of PyTorch.
*/
public class PyTorchWorkerComponent extends AbstractComponent {
public PyTorchWorkerComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager,
PyTorchRunJobParameters parameters, Role role,
PyTorchLaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters, role,
yarnConfig, launchCommandFactory);
}
@Override
public Component createComponent() throws IOException {
return createComponentInternal();
}
}

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate
* PyTorch Native Service components.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.component;

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package contains classes to generate
* PyTorch-related Native Service runtime artifacts.
*/
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch;

View File

@ -20,7 +20,7 @@ import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.ServiceApiConstants; import org.apache.hadoop.yarn.service.api.ServiceApiConstants;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.common.Envs; import org.apache.hadoop.yarn.submarine.common.Envs;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import java.util.Map; import java.util.Map;
@ -35,10 +35,10 @@ public final class TensorFlowCommons {
} }
public static void addCommonEnvironments(Component component, public static void addCommonEnvironments(Component component,
TaskType taskType) { Role role) {
Map<String, String> envs = component.getConfiguration().getEnv(); Map<String, String> envs = component.getConfiguration().getEnv();
envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID); envs.put(Envs.TASK_INDEX_ENV, ServiceApiConstants.COMPONENT_ID);
envs.put(Envs.TASK_TYPE_ENV, taskType.name()); envs.put(Envs.TASK_TYPE_ENV, role.getName());
} }
public static String getUserName() { public static String getUserName() {
@ -49,8 +49,8 @@ public final class TensorFlowCommons {
return yarnConfig.get("hadoop.registry.dns.domain-name"); return yarnConfig.get("hadoop.registry.dns.domain-name");
} }
public static String getScriptFileName(TaskType taskType) { public static String getScriptFileName(Role role) {
return "run-" + taskType.name() + ".sh"; return "run-" + role.getName() + ".sh";
} }
public static String getTFConfigEnv(String componentName, int nWorkers, public static String getTFConfigEnv(String componentName, int nWorkers,

View File

@ -16,39 +16,24 @@
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow; package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.KerberosPrincipal;
import org.apache.hadoop.yarn.service.api.records.Service; import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.submarine.client.cli.param.Quicklink; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext; import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractServiceSpec;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowPsComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorFlowWorkerComponent;
import org.apache.hadoop.yarn.submarine.utils.KerberosPrincipalFactory;
import org.apache.hadoop.yarn.submarine.utils.Localizer; import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getDNSDomain;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons.getUserName;
import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL; import static org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.component.TensorBoardComponent.TENSORBOARD_QUICKLINK_LABEL;
import static org.apache.hadoop.yarn.submarine.utils.DockerUtilities.getDockerArtifact;
import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handleServiceEnvs;
/** /**
* This class contains all the logic to create an instance * This class contains all the logic to create an instance
@ -56,42 +41,34 @@ import static org.apache.hadoop.yarn.submarine.utils.EnvironmentUtilities.handle
* Worker,PS and Tensorboard components are added to the Service * Worker,PS and Tensorboard components are added to the Service
* based on the value of the received {@link RunJobParameters}. * based on the value of the received {@link RunJobParameters}.
*/ */
public class TensorFlowServiceSpec implements ServiceSpec { public class TensorFlowServiceSpec extends AbstractServiceSpec {
private static final Logger LOG = private static final Logger LOG =
LoggerFactory.getLogger(TensorFlowServiceSpec.class); LoggerFactory.getLogger(TensorFlowServiceSpec.class);
private final TensorFlowRunJobParameters tensorFlowParameters;
private final RemoteDirectoryManager remoteDirectoryManager; public TensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
private final RunJobParameters parameters;
private final Configuration yarnConfig;
private final FileSystemOperations fsOperations;
private final LaunchCommandFactory launchCommandFactory;
private final Localizer localizer;
public TensorFlowServiceSpec(RunJobParameters parameters,
ClientContext clientContext, FileSystemOperations fsOperations, ClientContext clientContext, FileSystemOperations fsOperations,
LaunchCommandFactory launchCommandFactory, Localizer localizer) { TensorFlowLaunchCommandFactory launchCommandFactory,
this.parameters = parameters; Localizer localizer) {
this.remoteDirectoryManager = clientContext.getRemoteDirectoryManager(); super(parameters, clientContext, fsOperations, launchCommandFactory,
this.yarnConfig = clientContext.getYarnConfig(); localizer);
this.fsOperations = fsOperations; this.tensorFlowParameters = parameters;
this.launchCommandFactory = launchCommandFactory;
this.localizer = localizer;
} }
@Override @Override
public ServiceWrapper create() throws IOException { public ServiceWrapper create() throws IOException {
LOG.info("Creating TensorFlow service spec");
ServiceWrapper serviceWrapper = createServiceSpecWrapper(); ServiceWrapper serviceWrapper = createServiceSpecWrapper();
if (parameters.getNumWorkers() > 0) { if (tensorFlowParameters.getNumWorkers() > 0) {
addWorkerComponents(serviceWrapper); addWorkerComponents(serviceWrapper, Framework.TENSORFLOW);
} }
if (parameters.getNumPS() > 0) { if (tensorFlowParameters.getNumPS() > 0) {
addPsComponent(serviceWrapper); addPsComponent(serviceWrapper);
} }
if (parameters.isTensorboardEnabled()) { if (tensorFlowParameters.isTensorboardEnabled()) {
createTensorBoardComponent(serviceWrapper); createTensorBoardComponent(serviceWrapper);
} }
@ -101,103 +78,23 @@ public class TensorFlowServiceSpec implements ServiceSpec {
return serviceWrapper; return serviceWrapper;
} }
private ServiceWrapper createServiceSpecWrapper() throws IOException {
Service serviceSpec = new Service();
serviceSpec.setName(parameters.getName());
serviceSpec.setVersion(String.valueOf(System.currentTimeMillis()));
serviceSpec.setArtifact(getDockerArtifact(parameters.getDockerImageName()));
KerberosPrincipal kerberosPrincipal = KerberosPrincipalFactory
.create(fsOperations, remoteDirectoryManager, parameters);
if (kerberosPrincipal != null) {
serviceSpec.setKerberosPrincipal(kerberosPrincipal);
}
handleServiceEnvs(serviceSpec, yarnConfig, parameters.getEnvars());
localizer.handleLocalizations(serviceSpec);
return new ServiceWrapper(serviceSpec);
}
private void createTensorBoardComponent(ServiceWrapper serviceWrapper) private void createTensorBoardComponent(ServiceWrapper serviceWrapper)
throws IOException { throws IOException {
TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations, TensorBoardComponent tbComponent = new TensorBoardComponent(fsOperations,
remoteDirectoryManager, parameters, launchCommandFactory, yarnConfig); remoteDirectoryManager, parameters,
(TensorFlowLaunchCommandFactory) launchCommandFactory, yarnConfig);
serviceWrapper.addComponent(tbComponent); serviceWrapper.addComponent(tbComponent);
addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL, addQuicklink(serviceWrapper.getService(), TENSORBOARD_QUICKLINK_LABEL,
tbComponent.getTensorboardLink()); tbComponent.getTensorboardLink());
} }
private static void addQuicklink(Service serviceSpec, String label,
String link) {
Map<String, String> quicklinks = serviceSpec.getQuicklinks();
if (quicklinks == null) {
quicklinks = new HashMap<>();
serviceSpec.setQuicklinks(quicklinks);
}
if (SubmarineLogs.isVerbose()) {
LOG.info("Added quicklink, " + label + "=" + link);
}
quicklinks.put(label, link);
}
private void handleQuicklinks(Service serviceSpec)
throws IOException {
List<Quicklink> quicklinks = parameters.getQuicklinks();
if (quicklinks != null && !quicklinks.isEmpty()) {
for (Quicklink ql : quicklinks) {
// Make sure it is a valid instance name
String instanceName = ql.getComponentInstanceName();
boolean found = false;
for (Component comp : serviceSpec.getComponents()) {
for (int i = 0; i < comp.getNumberOfContainers(); i++) {
String possibleInstanceName = comp.getName() + "-" + i;
if (possibleInstanceName.equals(instanceName)) {
found = true;
break;
}
}
}
if (!found) {
throw new IOException(
"Couldn't find a component instance = " + instanceName
+ " while adding quicklink");
}
String link = ql.getProtocol()
+ YarnServiceUtils.getDNSName(serviceSpec.getName(), instanceName,
getUserName(), getDNSDomain(yarnConfig), ql.getPort());
addQuicklink(serviceSpec, ql.getLabel(), link);
}
}
}
// Handle worker and primary_worker.
private void addWorkerComponents(ServiceWrapper serviceWrapper)
throws IOException {
addWorkerComponent(serviceWrapper, parameters, TaskType.PRIMARY_WORKER);
if (parameters.getNumWorkers() > 1) {
addWorkerComponent(serviceWrapper, parameters, TaskType.WORKER);
}
}
private void addWorkerComponent(ServiceWrapper serviceWrapper,
RunJobParameters parameters, TaskType taskType) throws IOException {
serviceWrapper.addComponent(
new TensorFlowWorkerComponent(fsOperations, remoteDirectoryManager,
parameters, taskType, launchCommandFactory, yarnConfig));
}
private void addPsComponent(ServiceWrapper serviceWrapper) private void addPsComponent(ServiceWrapper serviceWrapper)
throws IOException { throws IOException {
serviceWrapper.addComponent( serviceWrapper.addComponent(
new TensorFlowPsComponent(fsOperations, remoteDirectoryManager, new TensorFlowPsComponent(fsOperations, remoteDirectoryManager,
launchCommandFactory, parameters, yarnConfig)); (TensorFlowLaunchCommandFactory) launchCommandFactory,
parameters, yarnConfig));
} }
} }

View File

@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -37,9 +37,9 @@ public class TensorBoardLaunchCommand extends AbstractLaunchCommand {
private final String checkpointPath; private final String checkpointPath;
public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, public TensorBoardLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters) Role role, Component component, RunJobParameters parameters)
throws IOException { throws IOException {
super(hadoopEnvSetup, taskType, component, parameters); super(hadoopEnvSetup, component, parameters, role.getName());
Objects.requireNonNull(parameters.getCheckpointPath(), Objects.requireNonNull(parameters.getCheckpointPath(),
"CheckpointPath must not be null as it is part " "CheckpointPath must not be null as it is part "
+ "of the tensorboard command!"); + "of the tensorboard command!");

View File

@ -18,8 +18,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
@ -28,6 +28,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.Objects;
/** /**
* Launch command implementation for * Launch command implementation for
@ -41,13 +42,16 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
private final int numberOfWorkers; private final int numberOfWorkers;
private final int numberOfPS; private final int numberOfPS;
private final String name; private final String name;
private final TaskType taskType; private final Role role;
TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters, Role role, Component component,
TensorFlowRunJobParameters parameters,
Configuration yarnConfig) throws IOException { Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters); super(hadoopEnvSetup, component, parameters,
this.taskType = taskType; role != null ? role.getName(): "");
Objects.requireNonNull(role, "TensorFlowRole must not be null!");
this.role = role;
this.name = parameters.getName(); this.name = parameters.getName();
this.distributed = parameters.isDistributed(); this.distributed = parameters.isDistributed();
this.numberOfWorkers = parameters.getNumWorkers(); this.numberOfWorkers = parameters.getNumWorkers();
@ -72,7 +76,7 @@ public abstract class TensorFlowLaunchCommand extends AbstractLaunchCommand {
// When distributed training is required // When distributed training is required
if (distributed) { if (distributed) {
String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv( String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(
taskType.getComponentName(), numberOfWorkers, role.getComponentName(), numberOfWorkers,
numberOfPS, name, numberOfPS, name,
TensorFlowCommons.getUserName(), TensorFlowCommons.getUserName(),
TensorFlowCommons.getDNSDomain(yarnConfig)); TensorFlowCommons.getDNSDomain(yarnConfig));

View File

@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -37,9 +37,10 @@ public class TensorFlowPsLaunchCommand extends TensorFlowLaunchCommand {
private final String launchCommand; private final String launchCommand;
public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, public TensorFlowPsLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup,
TaskType taskType, Component component, RunJobParameters parameters, Role role, Component component,
TensorFlowRunJobParameters parameters,
Configuration yarnConfig) throws IOException { Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig); super(hadoopEnvSetup, role, component, parameters, yarnConfig);
this.launchCommand = parameters.getPSLaunchCmd(); this.launchCommand = parameters.getPSLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) { if (StringUtils.isEmpty(this.launchCommand)) {

View File

@ -19,8 +19,8 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs; import org.apache.hadoop.yarn.submarine.common.conf.SubmarineLogs;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -37,10 +37,10 @@ public class TensorFlowWorkerLaunchCommand extends TensorFlowLaunchCommand {
private final String launchCommand; private final String launchCommand;
public TensorFlowWorkerLaunchCommand( public TensorFlowWorkerLaunchCommand(
HadoopEnvironmentSetup hadoopEnvSetup, TaskType taskType, HadoopEnvironmentSetup hadoopEnvSetup, Role role,
Component component, RunJobParameters parameters, Component component, TensorFlowRunJobParameters parameters,
Configuration yarnConfig) throws IOException { Configuration yarnConfig) throws IOException {
super(hadoopEnvSetup, taskType, component, parameters, yarnConfig); super(hadoopEnvSetup, role, component, parameters, yarnConfig);
this.launchCommand = parameters.getWorkerLaunchCmd(); this.launchCommand = parameters.getWorkerLaunchCmd();
if (StringUtils.isEmpty(this.launchCommand)) { if (StringUtils.isEmpty(this.launchCommand)) {

View File

@ -19,13 +19,14 @@ package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.compone
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component; import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum; import org.apache.hadoop.yarn.service.api.records.Component.RestartPolicyEnum;
import org.apache.hadoop.yarn.submarine.client.cli.param.RunJobParameters; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TaskType; import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.TensorFlowRole;
import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager; import org.apache.hadoop.yarn.submarine.common.fs.RemoteDirectoryManager;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.AbstractComponent;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchCommandFactory; import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -54,35 +55,38 @@ public class TensorBoardComponent extends AbstractComponent {
public TensorBoardComponent(FileSystemOperations fsOperations, public TensorBoardComponent(FileSystemOperations fsOperations,
RemoteDirectoryManager remoteDirectoryManager, RemoteDirectoryManager remoteDirectoryManager,
RunJobParameters parameters, RunJobParameters parameters,
LaunchCommandFactory launchCommandFactory, TensorFlowLaunchCommandFactory launchCommandFactory,
Configuration yarnConfig) { Configuration yarnConfig) {
super(fsOperations, remoteDirectoryManager, parameters, super(fsOperations, remoteDirectoryManager, parameters,
TaskType.TENSORBOARD, yarnConfig, launchCommandFactory); TensorFlowRole.TENSORBOARD, yarnConfig, launchCommandFactory);
} }
@Override @Override
public Component createComponent() throws IOException { public Component createComponent() throws IOException {
Objects.requireNonNull(parameters.getTensorboardResource(), TensorFlowRunJobParameters tensorFlowParams =
(TensorFlowRunJobParameters) this.parameters;
Objects.requireNonNull(tensorFlowParams.getTensorboardResource(),
"TensorBoard resource must not be null!"); "TensorBoard resource must not be null!");
Component component = new Component(); Component component = new Component();
component.setName(taskType.getComponentName()); component.setName(role.getComponentName());
component.setNumberOfContainers(1L); component.setNumberOfContainers(1L);
component.setRestartPolicy(RestartPolicyEnum.NEVER); component.setRestartPolicy(RestartPolicyEnum.NEVER);
component.setResource(convertYarnResourceToServiceResource( component.setResource(convertYarnResourceToServiceResource(
parameters.getTensorboardResource())); tensorFlowParams.getTensorboardResource()));
if (parameters.getTensorboardDockerImage() != null) { if (tensorFlowParams.getTensorboardDockerImage() != null) {
component.setArtifact( component.setArtifact(
getDockerArtifact(parameters.getTensorboardDockerImage())); getDockerArtifact(tensorFlowParams.getTensorboardDockerImage()));
} }
addCommonEnvironments(component, taskType); addCommonEnvironments(component, role);
generateLaunchCommand(component); generateLaunchCommand(component);
tensorboardLink = "http://" + YarnServiceUtils.getDNSName( tensorboardLink = "http://" + YarnServiceUtils.getDNSName(
parameters.getName(), parameters.getName(),
taskType.getComponentName() + "-" + 0, getUserName(), role.getComponentName() + "-" + 0, getUserName(),
getDNSDomain(yarnConfig), DEFAULT_PORT); getDNSDomain(yarnConfig), DEFAULT_PORT);
LOG.info("Link to tensorboard:" + tensorboardLink); LOG.info("Link to tensorboard:" + tensorboardLink);

Some files were not shown because too many files have changed in this diff Show More