YARN-9060. [YARN-8851] Phase 1 - Support device isolation and use the Nvidia GPU plugin as an example. Contributed by Zhankun Tang.

This commit is contained in:
Sunil G 2019-02-18 15:57:11 +05:30
parent 0f2b65c3da
commit db4d1a1e2f
20 changed files with 1970 additions and 89 deletions

View File

@ -22,4 +22,9 @@ feature.tc.enabled=false
#[fpga]
# module.enabled=## Enable/Disable the FPGA resource handler module. set to "true" to enable, disabled by default
# fpga.major-device-number=## Major device number of FPGA, by default is 246. Strongly recommend setting this
# fpga.allowed-device-minor-numbers=## Comma separated allowed minor device numbers, empty means all FPGA devices managed by YARN.
# fpga.allowed-device-minor-numbers=## Comma separated allowed minor device numbers, empty means all FPGA devices managed by YARN.
# The configs below deal with settings for resource handled by pluggable device plugin framework
#[devices]
# module.enabled=## Enable/Disable the device resource handler module for isolation. Disabled by default.
# devices.denied-numbers=## Blacklisted devices not permitted to use. The format is comma separated "majorNumber:minorNumber". For instance, "195:1,195:2". Leave it empty means default devices reported by device plugin are all allowed.

View File

@ -135,6 +135,7 @@ add_library(container
main/native/container-executor/impl/modules/common/module-configs.c
main/native/container-executor/impl/modules/gpu/gpu-module.c
main/native/container-executor/impl/modules/fpga/fpga-module.c
main/native/container-executor/impl/modules/devices/devices-module.c
main/native/container-executor/impl/utils/docker-util.c
)
@ -169,6 +170,7 @@ add_executable(cetest
main/native/container-executor/test/modules/cgroups/test-cgroups-module.cc
main/native/container-executor/test/modules/gpu/test-gpu-module.cc
main/native/container-executor/test/modules/fpga/test-fpga-module.cc
main/native/container-executor/test/modules/devices/test-devices-module.cc
main/native/container-executor/test/test_util.cc
main/native/container-executor/test/utils/test_docker_util.cc)
target_link_libraries(cetest gtest container)

View File

@ -181,8 +181,8 @@ public final class Device implements Serializable, Comparable {
// default -1 representing the value is not set
private int id = -1;
private String devPath = "";
private int majorNumber;
private int minorNumber;
private int majorNumber = -1;
private int minorNumber = -1;
private String busID = "";
private boolean isHealthy;
private String status = "";

View File

@ -54,6 +54,7 @@ public class PrivilegedOperation {
RUN_DOCKER_CMD("--run-docker"),
GPU("--module-gpu"),
FPGA("--module-fpga"),
DEVICE("--module-devices"),
LIST_AS_USER(""), // no CLI switch supported yet.
ADD_NUMA_PARAMS(""), // no CLI switch supported yet.
REMOVE_DOCKER_CONTAINER("--remove-docker-container"),

View File

@ -0,0 +1,240 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
/**
* Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
* non-Docker container.
* */
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin {
public static final Logger LOG = LoggerFactory.getLogger(
NvidiaGPUPluginForRuntimeV2.class);
public static final String NV_RESOURCE_NAME = "nvidia.com/gpu";
private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();
private Map<String, String> environment = new HashMap<>();
// If this environment is set, use it directly
private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH";
private static final String DEFAULT_BINARY_NAME = "nvidia-smi";
private static final String DEV_NAME_PREFIX = "nvidia";
private String pathOfGpuBinary = null;
// command should not run more than 10 sec.
private static final int MAX_EXEC_TIMEOUT_MS = 10 * 1000;
// When executable path not set, try to search default dirs
// By default search /usr/bin, /bin, and /usr/local/nvidia/bin (when
// launched by nvidia-docker.
private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of(
"/usr/bin", "/bin", "/usr/local/nvidia/bin");
@Override
public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
return DeviceRegisterRequest.Builder.newInstance()
.setResourceName(NV_RESOURCE_NAME).build();
}
@Override
public Set<Device> getDevices() throws Exception {
shellExecutor.searchBinary();
TreeSet<Device> r = new TreeSet<>();
String output;
try {
output = shellExecutor.getDeviceInfo();
String[] lines = output.trim().split("\n");
int id = 0;
for (String oneLine : lines) {
String[] tokensEachLine = oneLine.split(",");
if (tokensEachLine.length != 2) {
throw new Exception("Cannot parse the output to get device info. "
+ "Unexpected format in it:" + oneLine);
}
String minorNumber = tokensEachLine[0].trim();
String busId = tokensEachLine[1].trim();
String majorNumber = getMajorNumber(DEV_NAME_PREFIX
+ minorNumber);
if (majorNumber != null) {
r.add(Device.Builder.newInstance()
.setId(id)
.setMajorNumber(Integer.parseInt(majorNumber))
.setMinorNumber(Integer.parseInt(minorNumber))
.setBusID(busId)
.setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber)
.setHealthy(true)
.build());
id++;
}
}
return r;
} catch (IOException e) {
if (LOG.isDebugEnabled()) {
LOG.debug("Failed to get output from " + pathOfGpuBinary);
}
throw new YarnException(e);
}
}
@Override
public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
YarnRuntimeType yarnRuntime) throws Exception {
if (LOG.isDebugEnabled()) {
LOG.debug("Generating runtime spec for allocated devices: "
+ allocatedDevices + ", " + yarnRuntime.getName());
}
if (yarnRuntime == YarnRuntimeType.RUNTIME_DOCKER) {
String nvidiaRuntime = "nvidia";
String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES";
StringBuffer gpuMinorNumbersSB = new StringBuffer();
for (Device device : allocatedDevices) {
gpuMinorNumbersSB.append(device.getMinorNumber() + ",");
}
String minorNumbers = gpuMinorNumbersSB.toString();
LOG.info("Nvidia Docker v2 assigned GPU: " + minorNumbers);
return DeviceRuntimeSpec.Builder.newInstance()
.addEnv(nvidiaVisibleDevices,
minorNumbers.substring(0, minorNumbers.length() - 1))
.setContainerRuntime(nvidiaRuntime)
.build();
}
return null;
}
@Override
public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {
// do nothing
}
// Get major number from device name.
private String getMajorNumber(String devName) {
String output = null;
// output "major:minor" in hex
try {
if (LOG.isDebugEnabled()) {
LOG.debug("Get major numbers from /dev/" + devName);
}
output = shellExecutor.getMajorMinorInfo(devName);
String[] strs = output.trim().split(":");
if (LOG.isDebugEnabled()) {
LOG.debug("stat output:" + output);
}
output = Integer.toString(Integer.parseInt(strs[0], 16));
} catch (IOException e) {
String msg =
"Failed to get major number from reading /dev/" + devName;
LOG.warn(msg);
} catch (NumberFormatException e) {
LOG.error("Failed to parse device major number from stat output");
output = null;
}
return output;
}
/**
* A shell wrapper class easy for test.
* */
public class NvidiaCommandExecutor {
public String getDeviceInfo() throws IOException {
return Shell.execCommand(environment,
new String[]{pathOfGpuBinary, "--query-gpu=index,pci.bus_id",
"--format=csv,noheader"}, MAX_EXEC_TIMEOUT_MS);
}
public String getMajorMinorInfo(String devName) throws IOException {
// output "major:minor" in hex
Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(
new String[]{"stat", "-c", "%t:%T", "/dev/" + devName});
shexec.execute();
return shexec.getOutput();
}
public void searchBinary() throws Exception {
if (pathOfGpuBinary != null) {
LOG.info("Skip searching, the nvidia gpu binary is already set: "
+ pathOfGpuBinary);
return;
}
// search env for the binary
String envBinaryPath = System.getenv(ENV_BINARY_PATH);
if (null != envBinaryPath) {
if (new File(envBinaryPath).exists()) {
pathOfGpuBinary = envBinaryPath;
LOG.info("Use nvidia gpu binary: " + pathOfGpuBinary);
return;
}
}
LOG.info("Search binary..");
// search if binary exists in default folders
File binaryFile;
boolean found = false;
for (String dir : DEFAULT_BINARY_SEARCH_DIRS) {
binaryFile = new File(dir, DEFAULT_BINARY_NAME);
if (binaryFile.exists()) {
found = true;
pathOfGpuBinary = binaryFile.getAbsolutePath();
LOG.info("Found binary:" + pathOfGpuBinary);
break;
}
}
if (!found) {
LOG.error("No binary found from env variable: "
+ ENV_BINARY_PATH + " or path "
+ DEFAULT_BINARY_SEARCH_DIRS.toString());
throw new Exception("No binary found for "
+ NvidiaGPUPluginForRuntimeV2.class);
}
}
}
@VisibleForTesting
public void setPathOfGpuBinary(String pathOfGpuBinary) {
this.pathOfGpuBinary = pathOfGpuBinary;
}
@VisibleForTesting
public void setShellExecutor(
NvidiaCommandExecutor shellExecutor) {
this.shellExecutor = shellExecutor;
}
}

View File

@ -0,0 +1,19 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;

View File

@ -95,6 +95,20 @@ public class DeviceMappingManager {
return devicePluginSchedulers;
}
@VisibleForTesting
public Set<Device> getAllocatedDevices(String resourceName,
ContainerId cId) {
Set<Device> assigned = new TreeSet<>();
Map<Device, ContainerId> assignedMap =
this.getAllUsedDevices().get(resourceName);
for (Map.Entry<Device, ContainerId> entry : assignedMap.entrySet()) {
if (entry.getValue().equals(cId)) {
assigned.add(entry.getKey());
}
}
return assigned;
}
public synchronized void addDeviceSet(String resourceName,
Set<Device> deviceSet) {
LOG.info("Adding new resource: " + "type:"
@ -148,8 +162,10 @@ public class DeviceMappingManager {
ContainerId containerId = container.getContainerId();
int requestedDeviceCount = getRequestedDeviceCount(resourceName,
requestedResource);
LOG.debug("Try allocating " + requestedDeviceCount
+ " " + resourceName);
if (LOG.isDebugEnabled()) {
LOG.debug("Try allocating " + requestedDeviceCount
+ " " + resourceName);
}
// Assign devices to container if requested some.
if (requestedDeviceCount > 0) {
if (requestedDeviceCount > getAvailableDevices(resourceName)) {
@ -245,18 +261,24 @@ public class DeviceMappingManager {
ContainerId containerId) {
Iterator<Map.Entry<Device, ContainerId>> iter =
allUsedDevices.get(resourceName).entrySet().iterator();
Map.Entry<Device, ContainerId> entry;
while (iter.hasNext()) {
if (iter.next().getValue().equals(containerId)) {
entry = iter.next();
if (entry.getValue().equals(containerId)) {
if (LOG.isDebugEnabled()) {
LOG.debug("Recycle devices: " + entry.getKey()
+ ", type: " + resourceName + " from " + containerId);
}
iter.remove();
}
}
}
public static int getRequestedDeviceCount(String resourceName,
public static int getRequestedDeviceCount(String resName,
Resource requestedResource) {
try {
return Long.valueOf(requestedResource.getResourceValue(
resourceName)).intValue();
resName)).intValue();
} catch (ResourceNotFoundException e) {
return 0;
}
@ -270,10 +292,7 @@ public class DeviceMappingManager {
private long getReleasingDevices(String resourceName) {
long releasingDevices = 0;
Map<Device, ContainerId> used = allUsedDevices.get(resourceName);
Iterator<Map.Entry<Device, ContainerId>> iter = used.entrySet()
.iterator();
while (iter.hasNext()) {
ContainerId containerId = iter.next().getValue();
for (ContainerId containerId : ImmutableSet.copyOf(used.values())) {
Container container = nmContext.getContainers().get(containerId);
if (container != null) {
if (container.isContainerInFinalStates()) {
@ -295,16 +314,20 @@ public class DeviceMappingManager {
DevicePluginScheduler dps) throws ResourceHandlerException {
if (null == dps) {
LOG.debug("Customized device plugin scheduler is preferred "
+ "but not implemented, use default logic");
if (LOG.isDebugEnabled()) {
LOG.debug("Customized device plugin scheduler is preferred "
+ "but not implemented, use default logic");
}
defaultScheduleAction(allowed, used,
assigned, containerId, count);
} else {
LOG.debug("Customized device plugin implemented,"
+ "use customized logic");
// Use customized device scheduler
LOG.debug("Try to schedule " + count
+ "(" + resourceName + ") using " + dps.getClass());
if (LOG.isDebugEnabled()) {
LOG.debug("Customized device plugin implemented,"
+ "use customized logic");
// Use customized device scheduler
LOG.debug("Try to schedule " + count
+ "(" + resourceName + ") using " + dps.getClass());
}
// Pass in unmodifiable set
Set<Device> dpsAllocated = dps.allocateDevices(
Sets.difference(allowed, used.keySet()),
@ -345,6 +368,7 @@ public class DeviceMappingManager {
private String resourceName;
private Set<Device> allowed = Collections.emptySet();
private Set<Device> denied = Collections.emptySet();
DeviceAllocation(String resName, Set<Device> a,
@ -362,6 +386,10 @@ public class DeviceMappingManager {
return allowed;
}
public Set<Device> getDenied() {
return denied;
}
@Override
public String toString() {
return "ResourceType: " + resourceName

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.yarn.api.records.ContainerId;
@ -49,10 +50,20 @@ public class DevicePluginAdapter implements ResourcePlugin {
private final static Log LOG = LogFactory.getLog(DevicePluginAdapter.class);
private final String resourceName;
private final DevicePlugin devicePlugin;
private DeviceMappingManager deviceMappingManager;
private DeviceResourceHandlerImpl deviceResourceHandler;
private DeviceResourceUpdaterImpl deviceResourceUpdater;
private DeviceResourceDockerRuntimePluginImpl deviceDockerCommandPlugin;
@VisibleForTesting
public void setDeviceResourceHandler(
DeviceResourceHandlerImpl deviceResourceHandler) {
this.deviceResourceHandler = deviceResourceHandler;
}
public DevicePluginAdapter(String name, DevicePlugin dp,
DeviceMappingManager dmm) {
@ -65,8 +76,16 @@ public class DevicePluginAdapter implements ResourcePlugin {
return deviceMappingManager;
}
public DevicePlugin getDevicePlugin() {
return devicePlugin;
}
@Override
public void initialize(Context context) throws YarnException {
deviceDockerCommandPlugin = new DeviceResourceDockerRuntimePluginImpl(
resourceName,
devicePlugin, this);
deviceResourceUpdater = new DeviceResourceUpdaterImpl(
resourceName, devicePlugin);
LOG.info(resourceName + " plugin adapter initialized");
@ -78,8 +97,8 @@ public class DevicePluginAdapter implements ResourcePlugin {
CGroupsHandler cGroupsHandler,
PrivilegedOperationExecutor privilegedOperationExecutor) {
this.deviceResourceHandler = new DeviceResourceHandlerImpl(resourceName,
devicePlugin, this, deviceMappingManager,
cGroupsHandler, privilegedOperationExecutor);
this, deviceMappingManager,
cGroupsHandler, privilegedOperationExecutor, nmContext);
return deviceResourceHandler;
}
@ -95,7 +114,7 @@ public class DevicePluginAdapter implements ResourcePlugin {
@Override
public DockerCommandPlugin getDockerCommandPluginInstance() {
return null;
return deviceDockerCommandPlugin;
}
@Override

View File

@ -0,0 +1,233 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerExecutionException;
import org.apache.hadoop.yarn.util.LRUCacheHashMap;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
/**
* Bridge DevicePlugin and the hooks related to lunch Docker container.
* When launching Docker container, DockerLinuxContainerRuntime will invoke
* this class's methods which get needed info back from DevicePlugin.
* */
public class DeviceResourceDockerRuntimePluginImpl
implements DockerCommandPlugin {
final static Log LOG = LogFactory.getLog(
DeviceResourceDockerRuntimePluginImpl.class);
private String resourceName;
private DevicePlugin devicePlugin;
private DevicePluginAdapter devicePluginAdapter;
private int maxCacheSize = 100;
// LRU to avoid memory leak if getCleanupDockerVolumesCommand not invoked.
private Map<ContainerId, Set<Device>> cachedAllocation =
Collections.synchronizedMap(new LRUCacheHashMap(maxCacheSize, true));
private Map<ContainerId, DeviceRuntimeSpec> cachedSpec =
Collections.synchronizedMap(new LRUCacheHashMap<>(maxCacheSize, true));
public DeviceResourceDockerRuntimePluginImpl(String resourceName,
DevicePlugin devicePlugin, DevicePluginAdapter devicePluginAdapter) {
this.resourceName = resourceName;
this.devicePlugin = devicePlugin;
this.devicePluginAdapter = devicePluginAdapter;
}
@Override
public void updateDockerRunCommand(DockerRunCommand dockerRunCommand,
Container container) throws ContainerExecutionException {
String containerId = container.getContainerId().toString();
if (LOG.isDebugEnabled()) {
LOG.debug("Try to update docker run command for: " + containerId);
}
if(!requestedDevice(resourceName, container)) {
return;
}
DeviceRuntimeSpec deviceRuntimeSpec = getRuntimeSpec(container);
if (deviceRuntimeSpec == null) {
LOG.warn("The device plugin: "
+ devicePlugin.getClass().getCanonicalName()
+ " returns null device runtime spec value for container: "
+ containerId);
return;
}
// handle runtime
dockerRunCommand.addRuntime(deviceRuntimeSpec.getContainerRuntime());
if (LOG.isDebugEnabled()) {
LOG.debug("Handle docker container runtime type: "
+ deviceRuntimeSpec.getContainerRuntime() + " for container: "
+ containerId);
}
// handle device mounts
Set<MountDeviceSpec> deviceMounts = deviceRuntimeSpec.getDeviceMounts();
if (LOG.isDebugEnabled()) {
LOG.debug("Handle device mounts: " + deviceMounts + " for container: "
+ containerId);
}
for (MountDeviceSpec mountDeviceSpec : deviceMounts) {
dockerRunCommand.addDevice(
mountDeviceSpec.getDevicePathInHost(),
mountDeviceSpec.getDevicePathInContainer());
}
// handle volume mounts
Set<MountVolumeSpec> mountVolumeSpecs = deviceRuntimeSpec.getVolumeMounts();
if (LOG.isDebugEnabled()) {
LOG.debug("Handle volume mounts: " + mountVolumeSpecs + " for container: "
+ containerId);
}
for (MountVolumeSpec mountVolumeSpec : mountVolumeSpecs) {
if (mountVolumeSpec.getReadOnly()) {
dockerRunCommand.addReadOnlyMountLocation(
mountVolumeSpec.getHostPath(),
mountVolumeSpec.getMountPath());
} else {
dockerRunCommand.addReadWriteMountLocation(
mountVolumeSpec.getHostPath(),
mountVolumeSpec.getMountPath());
}
}
// handle envs
dockerRunCommand.addEnv(deviceRuntimeSpec.getEnvs());
if (LOG.isDebugEnabled()) {
LOG.debug("Handle envs: " + deviceRuntimeSpec.getEnvs()
+ " for container: " + containerId);
}
}
@Override
public DockerVolumeCommand getCreateDockerVolumeCommand(Container container)
throws ContainerExecutionException {
if(!requestedDevice(resourceName, container)) {
return null;
}
DeviceRuntimeSpec deviceRuntimeSpec = getRuntimeSpec(container);
if (deviceRuntimeSpec == null) {
return null;
}
Set<VolumeSpec> volumeClaims = deviceRuntimeSpec.getVolumeSpecs();
for (VolumeSpec volumeSec: volumeClaims) {
if (volumeSec.getVolumeOperation().equals(VolumeSpec.CREATE)) {
DockerVolumeCommand command = new DockerVolumeCommand(
DockerVolumeCommand.VOLUME_CREATE_SUB_COMMAND);
command.setDriverName(volumeSec.getVolumeDriver());
command.setVolumeName(volumeSec.getVolumeName());
if (LOG.isDebugEnabled()) {
LOG.debug("Get volume create request from plugin:" + volumeClaims
+ " for container: " + container.getContainerId().toString());
}
return command;
}
}
return null;
}
@Override
public DockerVolumeCommand getCleanupDockerVolumesCommand(Container container)
throws ContainerExecutionException {
if(!requestedDevice(resourceName, container)) {
return null;
}
Set<Device> allocated = getAllocatedDevices(container);
try {
devicePlugin.onDevicesReleased(allocated);
} catch (Exception e) {
LOG.warn("Exception thrown in onDeviceReleased of "
+ devicePlugin.getClass() + "for container: "
+ container.getContainerId().toString(), e);
}
// remove cache
ContainerId containerId = container.getContainerId();
cachedAllocation.remove(containerId);
cachedSpec.remove(containerId);
return null;
}
protected boolean requestedDevice(String resName, Container container) {
return DeviceMappingManager.
getRequestedDeviceCount(resName, container.getResource()) > 0;
}
private Set<Device> getAllocatedDevices(Container container) {
// get allocated devices
Set<Device> allocated;
ContainerId containerId = container.getContainerId();
allocated = cachedAllocation.get(containerId);
if (allocated != null) {
return allocated;
}
allocated = devicePluginAdapter
.getDeviceMappingManager()
.getAllocatedDevices(resourceName, containerId);
if (LOG.isDebugEnabled()) {
LOG.debug("Get allocation from deviceMappingManager: "
+ allocated + ", " + resourceName + " for container: " + containerId);
}
cachedAllocation.put(containerId, allocated);
return allocated;
}
public synchronized DeviceRuntimeSpec getRuntimeSpec(Container container) {
ContainerId containerId = container.getContainerId();
DeviceRuntimeSpec deviceRuntimeSpec = cachedSpec.get(containerId);
if (deviceRuntimeSpec == null) {
Set<Device> allocated = getAllocatedDevices(container);
if (allocated == null || allocated.size() == 0) {
LOG.error("Cannot get allocation for container:" + containerId);
return null;
}
try {
deviceRuntimeSpec = devicePlugin.onDevicesAllocated(allocated,
YarnRuntimeType.RUNTIME_DOCKER);
} catch (Exception e) {
LOG.error("Exception thrown in onDeviceAllocated of "
+ devicePlugin.getClass() + " for container: " + containerId, e);
}
if (deviceRuntimeSpec == null) {
LOG.error("Null DeviceRuntimeSpec value got from "
+ devicePlugin.getClass() + " for container: "
+ containerId + ", please check plugin logic");
return null;
}
cachedSpec.put(containerId, deviceRuntimeSpec);
}
return deviceRuntimeSpec;
}
}

View File

@ -18,20 +18,29 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.DockerLinuxContainerRuntime;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
@ -52,19 +61,45 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
private final CGroupsHandler cGroupsHandler;
private final PrivilegedOperationExecutor privilegedOperationExecutor;
private final DevicePluginAdapter devicePluginAdapter;
private final Context nmContext;
private ShellWrapper shellWrapper;
public DeviceResourceHandlerImpl(String reseName,
DevicePlugin devPlugin,
// This will be used by container-executor to add necessary clis
public static final String EXCLUDED_DEVICES_CLI_OPTION = "--excluded_devices";
public static final String ALLOWED_DEVICES_CLI_OPTION = "--allowed_devices";
public static final String CONTAINER_ID_CLI_OPTION = "--container_id";
public DeviceResourceHandlerImpl(String resName,
DevicePluginAdapter devPluginAdapter,
DeviceMappingManager devMappingManager,
CGroupsHandler cgHandler,
PrivilegedOperationExecutor operation) {
PrivilegedOperationExecutor operation,
Context ctx) {
this.devicePluginAdapter = devPluginAdapter;
this.resourceName = reseName;
this.devicePlugin = devPlugin;
this.resourceName = resName;
this.devicePlugin = devPluginAdapter.getDevicePlugin();
this.cGroupsHandler = cgHandler;
this.privilegedOperationExecutor = operation;
this.deviceMappingManager = devMappingManager;
this.nmContext = ctx;
this.shellWrapper = new ShellWrapper();
}
@VisibleForTesting
public DeviceResourceHandlerImpl(String resName,
DevicePluginAdapter devPluginAdapter,
DeviceMappingManager devMappingManager,
CGroupsHandler cgHandler,
PrivilegedOperationExecutor operation,
Context ctx, ShellWrapper shell) {
this.devicePluginAdapter = devPluginAdapter;
this.resourceName = resName;
this.devicePlugin = devPluginAdapter.getDevicePlugin();
this.cGroupsHandler = cgHandler;
this.privilegedOperationExecutor = operation;
this.deviceMappingManager = devMappingManager;
this.nmContext = ctx;
this.shellWrapper = shell;
}
@Override
@ -98,11 +133,13 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
String containerIdStr = container.getContainerId().toString();
DeviceMappingManager.DeviceAllocation allocation =
deviceMappingManager.assignDevices(resourceName, container);
LOG.debug("Allocated to "
+ containerIdStr + ": " + allocation);
if (LOG.isDebugEnabled()) {
LOG.debug("Allocated to "
+ containerIdStr + ": " + allocation);
}
DeviceRuntimeSpec spec;
try {
devicePlugin.onDevicesAllocated(
spec = devicePlugin.onDevicesAllocated(
allocation.getAllowed(), YarnRuntimeType.RUNTIME_DEFAULT);
} catch (Exception e) {
throw new ResourceHandlerException("Exception thrown from"
@ -110,13 +147,95 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
}
// cgroups operation based on allocation
/**
* TODO: implement a general container-executor device module
* */
if (spec != null) {
LOG.warn("Runtime spec in non-Docker container is not supported yet!");
}
// Create device cgroups for the container
cGroupsHandler.createCGroup(CGroupsHandler.CGroupController.DEVICES,
containerIdStr);
// non-Docker, use cgroups to do isolation
if (!DockerLinuxContainerRuntime.isDockerContainerRequested(
nmContext.getConf(),
container.getLaunchContext().getEnvironment())) {
tryIsolateDevices(allocation, containerIdStr);
List<PrivilegedOperation> ret = new ArrayList<>();
ret.add(new PrivilegedOperation(
PrivilegedOperation.OperationType.ADD_PID_TO_CGROUP,
PrivilegedOperation.CGROUP_ARG_PREFIX + cGroupsHandler
.getPathForCGroupTasks(CGroupsHandler.CGroupController.DEVICES,
containerIdStr)));
return ret;
}
return null;
}
/**
* Try set cgroup devices params for the container using container-executor.
* If it has real device major number, minor number or dev path,
* we'll do the enforcement. Otherwise, won't do it.
*
* */
private void tryIsolateDevices(
DeviceMappingManager.DeviceAllocation allocation,
String containerIdStr) throws ResourceHandlerException {
try {
// Execute c-e to setup device isolation before launch the container
PrivilegedOperation privilegedOperation = new PrivilegedOperation(
PrivilegedOperation.OperationType.DEVICE,
Arrays.asList(CONTAINER_ID_CLI_OPTION, containerIdStr));
boolean needNativeDeviceOperation = false;
int majorNumber;
int minorNumber;
List<String> devNumbers = new ArrayList<>();
if (!allocation.getDenied().isEmpty()) {
DeviceType devType;
for (Device deniedDevice : allocation.getDenied()) {
majorNumber = deniedDevice.getMajorNumber();
minorNumber = deniedDevice.getMinorNumber();
// Add device type
devType = getDeviceType(deniedDevice);
if (devType != null) {
devNumbers.add(devType.getName() + "-" + majorNumber + ":"
+ minorNumber + "-rwm");
}
}
if (devNumbers.size() != 0) {
privilegedOperation.appendArgs(
Arrays.asList(EXCLUDED_DEVICES_CLI_OPTION,
StringUtils.join(",", devNumbers)));
needNativeDeviceOperation = true;
}
}
if (!allocation.getAllowed().isEmpty()) {
devNumbers.clear();
for (Device allowedDevice : allocation.getAllowed()) {
majorNumber = allowedDevice.getMajorNumber();
minorNumber = allowedDevice.getMinorNumber();
if (majorNumber != -1 && minorNumber != -1) {
devNumbers.add(majorNumber + ":" + minorNumber);
}
}
if (devNumbers.size() > 0) {
privilegedOperation.appendArgs(
Arrays.asList(ALLOWED_DEVICES_CLI_OPTION,
StringUtils.join(",", devNumbers)));
needNativeDeviceOperation = true;
}
}
if (needNativeDeviceOperation) {
privilegedOperationExecutor.executePrivilegedOperation(
privilegedOperation, true);
}
} catch (PrivilegedOperationException e) {
cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
containerIdStr);
LOG.warn("Could not update cgroup for container", e);
throw new ResourceHandlerException(e);
}
}
@Override
public synchronized List<PrivilegedOperation> reacquireContainer(
ContainerId containerId) throws ResourceHandlerException {
@ -134,6 +253,8 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
public synchronized List<PrivilegedOperation> postComplete(
ContainerId containerId) throws ResourceHandlerException {
deviceMappingManager.cleanupAssignedDevices(resourceName, containerId);
cGroupsHandler.deleteCGroup(CGroupsHandler.CGroupController.DEVICES,
containerId.toString());
return null;
}
@ -151,4 +272,73 @@ public class DeviceResourceHandlerImpl implements ResourceHandler {
", devicePluginAdapter=" + devicePluginAdapter +
'}';
}
public DeviceType getDeviceType(Device device) {
String devName = device.getDevPath();
if (devName.isEmpty()) {
LOG.warn("Empty device path provided, try to get device type from " +
"major:minor device number");
int major = device.getMajorNumber();
int minor = device.getMinorNumber();
if (major == -1 && minor == -1) {
LOG.warn("Non device number provided, cannot decide the device type");
return null;
}
// Get type from the device numbers
return getDeviceTypeFromDeviceNumber(device.getMajorNumber(),
device.getMinorNumber());
}
DeviceType deviceType;
try {
if (LOG.isDebugEnabled()) {
LOG.debug("Try to get device type from device path: " + devName);
}
String output = shellWrapper.getDeviceFileType(devName);
if (LOG.isDebugEnabled()) {
LOG.debug("stat output:" + output);
}
deviceType = output.startsWith("c") ? DeviceType.CHAR : DeviceType.BLOCK;
} catch (IOException e) {
String msg =
"Failed to get device type from stat " + devName;
LOG.warn(msg);
return null;
}
return deviceType;
}
/**
* Get the device type used for cgroups value set.
* If sys file "/sys/dev/block/major:minor" exists, it's block device.
* Otherwise, it's char device. An exception is that Nvidia GPU doesn't
* create this sys file. so assume character device by default.
*/
public DeviceType getDeviceTypeFromDeviceNumber(int major, int minor) {
if (shellWrapper.existFile("/sys/dev/block/"
+ major + ":" + minor)) {
return DeviceType.BLOCK;
}
return DeviceType.CHAR;
}
/**
* Enum for Linux device type. Used when updating device cgroups params.
* "b" represents block device
* "c" represents character device
* */
private enum DeviceType {
BLOCK("b"),
CHAR("c");
private final String name;
DeviceType(String n) {
this.name = n;
}
public String getName() {
return name;
}
}
}

View File

@ -0,0 +1,46 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import org.apache.hadoop.util.Shell;
import java.io.File;
import java.io.IOException;
/**
* A shell Wrapper to ease testing.
* */
public class ShellWrapper {
public String getDeviceFileType(String devName) throws IOException {
Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(
new String[]{"stat", "-c", "%F", devName});
shexec.execute();
return shexec.getOutput();
}
public boolean existFile(String path) {
File searchFile =
new File(path);
if (searchFile.exists()) {
return true;
}
return false;
}
}

View File

@ -24,6 +24,7 @@
#include "modules/gpu/gpu-module.h"
#include "modules/fpga/fpga-module.h"
#include "modules/cgroups/cgroups-operations.h"
#include "modules/devices/devices-module.h"
#include "utils/string-utils.h"
#include <errno.h>
@ -289,6 +290,11 @@ static int validate_arguments(int argc, char **argv , int *operation) {
&argv[1]);
}
if (strcmp("--module-devices", argv[1]) == 0) {
return handle_devices_request(&update_cgroups_parameters, "devices", argc - 1,
&argv[1]);
}
if (strcmp("--checksetup", argv[1]) == 0) {
*operation = CHECK_SETUP;
return 0;

View File

@ -132,7 +132,7 @@ int update_cgroups_parameters(
goto cleanup;
}
fprintf(ERRORFILE, "CGroups: Updating cgroups, path=%s, value=%s",
fprintf(ERRORFILE, "CGroups: Updating cgroups, path=%s, value=%s\n",
full_path, value);
// Write values to file

View File

@ -0,0 +1,281 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "configuration.h"
#include "container-executor.h"
#include "utils/string-utils.h"
#include "modules/devices/devices-module.h"
#include "modules/cgroups/cgroups-operations.h"
#include "modules/common/module-configs.h"
#include "modules/common/constants.h"
#include "util.h"
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <getopt.h>
#include <unistd.h>
#include <sys/stat.h>
#define EXCLUDED_DEVICES_OPTION "excluded_devices"
#define ALLOWED_DEVICES_OPTION "allowed_devices"
#define CONTAINER_ID_OPTION "container_id"
#define MAX_CONTAINER_ID_LEN 128
static const struct section* cfg_section;
// Search a string in a string list, return 1 when found
static int search_in_list(char** list, char* token) {
int i = 0;
char** iterator = list;
// search token in list
while (iterator[i] != NULL) {
if (strstr(token, iterator[i]) != NULL ||
strstr(iterator[i], token) != NULL) {
// Found deny device in allowed list
return 1;
}
i++;
}
return 0;
}
static int is_block_device(const char* value) {
int is_block = 0;
int max_path_size = 512;
char* block_path = malloc(max_path_size);
if (block_path == NULL) {
fprintf(ERRORFILE, "Failed to allocate memory for sys device path string.\n");
fflush(ERRORFILE);
goto cleanup;
}
if (snprintf(block_path, max_path_size, "/sys/dev/block/%s",
value) < 0) {
fprintf(ERRORFILE, "Failed to construct system block device path.\n");
goto cleanup;
}
struct stat sb;
// file exists, is block device
if (stat(block_path, &sb) == 0) {
is_block = 1;
}
cleanup:
if (block_path) {
free(block_path);
}
return is_block;
}
static int internal_handle_devices_request(
update_cgroups_parameters_function update_cgroups_parameters_func_p,
char** deny_devices_number_tokens,
char** allow_devices_number_tokens,
const char* container_id) {
int return_code = 0;
char** ce_denied_numbers = NULL;
char* ce_denied_str = get_section_value(DEVICES_DENIED_NUMBERS,
cfg_section);
// Get denied "major:minor" device numbers from cfg, if not set, means all
// devices can be used by YARN.
if (ce_denied_str != NULL) {
ce_denied_numbers = split_delimiter(ce_denied_str, ",");
if (NULL == ce_denied_numbers) {
fprintf(ERRORFILE,
"Invalid value set for %s, value=%s\n",
DEVICES_DENIED_NUMBERS,
ce_denied_str);
return_code = -1;
goto cleanup;
}
// Check allowed devices passed in
char** allow_iterator = allow_devices_number_tokens;
int allow_count = 0;
while (allow_iterator[allow_count] != NULL) {
if (search_in_list(ce_denied_numbers, allow_iterator[allow_count])) {
fprintf(ERRORFILE,
"Trying to allow device with device number=%s which is not permitted in container-executor.cfg. %s\n",
allow_iterator[allow_count],
"It could be caused by a mismatch of devices reported by device plugin");
return_code = -1;
goto cleanup;
}
allow_count++;
}
// Deny devices configured in c-e.cfg
char** ce_iterator = ce_denied_numbers;
int ce_count = 0;
while (ce_iterator[ce_count] != NULL) {
// skip if duplicate with denied numbers passed in
if (search_in_list(deny_devices_number_tokens, ce_iterator[ce_count])) {
ce_count++;
continue;
}
char param_value[128];
char type = 'c';
memset(param_value, 0, sizeof(param_value));
if (is_block_device(ce_iterator[ce_count])) {
type = 'b';
}
snprintf(param_value, sizeof(param_value), "%c %s rwm",
type,
ce_iterator[ce_count]);
// Update device cgroups value
int rc = update_cgroups_parameters_func_p("devices", "deny",
container_id, param_value);
if (0 != rc) {
fprintf(ERRORFILE, "CGroups: Failed to update cgroups. %s\n", param_value);
return_code = -1;
goto cleanup;
}
ce_count++;
}
}
// Deny devices passed from java side
char** iterator = deny_devices_number_tokens;
int count = 0;
char* value = NULL;
int index = 0;
while (iterator[count] != NULL) {
// Replace like "c-242:0-rwm" to "c 242:0 rwm"
value = iterator[count];
index = 0;
while (value[index] != '\0') {
if (value[index] == '-') {
value[index] = ' ';
}
index++;
}
// Update device cgroups value
int rc = update_cgroups_parameters_func_p("devices", "deny",
container_id, iterator[count]);
if (0 != rc) {
fprintf(ERRORFILE, "CGroups: Failed to update cgroups\n");
return_code = -1;
goto cleanup;
}
count++;
}
cleanup:
if (ce_denied_numbers != NULL) {
free_values(ce_denied_numbers);
}
return return_code;
}
void reload_devices_configuration() {
cfg_section = get_configuration_section(DEVICES_MODULE_SECTION_NAME, get_cfg());
}
/*
* Format of devices request commandline:
* The excluded_devices is comma separated device cgroups values with device type.
* The "-" will be replaced with " " to match the cgroups parameter
* c-e --module-devices \
* --excluded_devices b-8:16-rwm,c-244:0-rwm,c-244:1-rwm \
* --allowed_devices 8:32,8:48,243:2 \
* --container_id container_x_y
*/
int handle_devices_request(update_cgroups_parameters_function func,
const char* module_name, int module_argc, char** module_argv) {
if (!cfg_section) {
reload_devices_configuration();
}
if (!module_enabled(cfg_section, DEVICES_MODULE_SECTION_NAME)) {
fprintf(ERRORFILE,
"Please make sure devices module is enabled before using it.\n");
return -1;
}
static struct option long_options[] = {
{EXCLUDED_DEVICES_OPTION, required_argument, 0, 'e' },
{ALLOWED_DEVICES_OPTION, required_argument, 0, 'a' },
{CONTAINER_ID_OPTION, required_argument, 0, 'c' },
{0, 0, 0, 0}
};
int c = 0;
int option_index = 0;
char** deny_device_value_tokens = NULL;
char** allow_device_value_tokens = NULL;
char container_id[MAX_CONTAINER_ID_LEN];
memset(container_id, 0, sizeof(container_id));
int failed = 0;
optind = 1;
while((c = getopt_long(module_argc, module_argv, "e:a:c:",
long_options, &option_index)) != -1) {
switch(c) {
case 'e':
deny_device_value_tokens = split_delimiter(optarg, ",");
break;
case 'a':
allow_device_value_tokens = split_delimiter(optarg, ",");
break;
case 'c':
if (!validate_container_id(optarg)) {
fprintf(ERRORFILE,
"Specified container_id=%s is invalid\n", optarg);
failed = 1;
goto cleanup;
}
strncpy(container_id, optarg, MAX_CONTAINER_ID_LEN);
break;
default:
fprintf(ERRORFILE,
"Unknown option in devices command character %d %c, optionindex = %d\n",
c, c, optind);
failed = 1;
goto cleanup;
}
}
if (0 == container_id[0]) {
fprintf(ERRORFILE,
"[%s] --container_id must be specified.\n", __func__);
failed = 1;
goto cleanup;
}
if (NULL == deny_device_value_tokens) {
// Devices number is null, skip following call.
fprintf(ERRORFILE, "--excluded_devices is not specified, skip cgroups call.\n");
goto cleanup;
}
failed = internal_handle_devices_request(func,
deny_device_value_tokens,
allow_device_value_tokens,
container_id);
cleanup:
if (deny_device_value_tokens) {
free_values(deny_device_value_tokens);
}
if (allow_device_value_tokens) {
free_values(allow_device_value_tokens);
}
return failed;
}

View File

@ -0,0 +1,45 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef __FreeBSD__
#define _WITH_GETLINE
#endif
#ifndef _MODULES_DEVICES_MUDULE_H_
#define _MODULES_DEVICES_MUDULE_H_
// Denied device list. value format is "major1:minor1,major2:minor2"
#define DEVICES_DENIED_NUMBERS "devices.denied-numbers"
#define DEVICES_MODULE_SECTION_NAME "devices"
// For unit test stubbing
typedef int (*update_cgroups_parameters_function)(const char*, const char*,
const char*, const char*);
/**
* Handle devices requests
*/
int handle_devices_request(update_cgroups_parameters_function func,
const char* module_name, int module_argc, char** module_argv);
/**
* Reload config from filesystem, visible for testing.
*/
void reload_devices_configuration();
#endif

View File

@ -44,6 +44,9 @@ char** split_delimiter(char *value, const char *delim) {
memset(return_values, 0, sizeof(char *) * return_values_size);
temp_tok = strtok_r(value, delim, &tempstr);
if (NULL == temp_tok) {
return_values[size++] = strdup(value);
}
while (temp_tok != NULL) {
temp_tok = strdup(temp_tok);
if (NULL == temp_tok) {

View File

@ -0,0 +1,298 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <gtest/gtest.h>
#include <sstream>
extern "C" {
#include "configuration.h"
#include "container-executor.h"
#include "modules/cgroups/cgroups-operations.h"
#include "modules/devices/devices-module.h"
#include "test/test-container-executor-common.h"
#include "util.h"
}
namespace ContainerExecutor {
class TestDevicesModule : public ::testing::Test {
protected:
virtual void SetUp() {
if (mkdirs(TEST_ROOT, 0755) != 0) {
fprintf(ERRORFILE, "Failed to mkdir TEST_ROOT: %s\n", TEST_ROOT);
exit(1);
}
LOGFILE = stdout;
ERRORFILE = stderr;
}
virtual void TearDown() {
}
};
static std::vector<const char*> cgroups_parameters_invoked;
static int mock_update_cgroups_parameters(
const char* controller_name,
const char* param_name,
const char* group_id,
const char* value) {
char* buf = (char*) malloc(128);
strcpy(buf, controller_name);
cgroups_parameters_invoked.push_back(buf);
buf = (char*) malloc(128);
strcpy(buf, param_name);
cgroups_parameters_invoked.push_back(buf);
buf = (char*) malloc(128);
strcpy(buf, group_id);
cgroups_parameters_invoked.push_back(buf);
buf = (char*) malloc(128);
strcpy(buf, value);
cgroups_parameters_invoked.push_back(buf);
return 0;
}
static void clear_cgroups_parameters_invoked() {
for (std::vector<const char*>::size_type i = 0; i < cgroups_parameters_invoked.size(); i++) {
free((void *) cgroups_parameters_invoked[i]);
}
cgroups_parameters_invoked.clear();
}
static void verify_param_updated_to_cgroups(
int argc, const char** argv) {
ASSERT_EQ(argc, cgroups_parameters_invoked.size());
int offset = 0;
while (offset < argc) {
ASSERT_STREQ(argv[offset], cgroups_parameters_invoked[offset]);
offset++;
}
}
static void write_and_load_devices_module_to_cfg(const char* cfg_filepath, int enabled) {
FILE *file = fopen(cfg_filepath, "w");
if (file == NULL) {
printf("FAIL: Could not open configuration file: %s\n", cfg_filepath);
exit(1);
}
fprintf(file, "[devices]\n");
if (enabled) {
fprintf(file, "module.enabled=true\n");
} else {
fprintf(file, "module.enabled=false\n");
}
fclose(file);
// Read config file
read_executor_config(cfg_filepath);
reload_devices_configuration();
}
static void append_config(const char* cfg_filepath, char values[]) {
FILE *file = fopen(cfg_filepath, "a");
if (file == NULL) {
printf("FAIL: Could not open configuration file: %s\n", cfg_filepath);
exit(1);
}
fprintf(file, "%s", values);
fclose(file);
// Read config file
read_executor_config(cfg_filepath);
reload_devices_configuration();
}
static void test_devices_module_enabled_disabled(int enabled) {
// Write config file.
const char *filename = TEST_ROOT "/test_cgroups_module_enabled_disabled.cfg";
write_and_load_devices_module_to_cfg(filename, enabled);
char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
char allowed_devices[] = "243:2";
char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
excluded_devices,
(char*) "--allowed_devices",
allowed_devices,
(char*) "--container_id",
(char*) "container_1498064906505_0001_01_000001" };
int rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 7, argv);
int EXPECTED_RC;
if (enabled) {
EXPECTED_RC = 0;
} else {
EXPECTED_RC = -1;
}
ASSERT_EQ(EXPECTED_RC, rc);
clear_cgroups_parameters_invoked();
free_executor_configurations();
}
TEST_F(TestDevicesModule, test_verify_device_module_calls_cgroup_parameter) {
// Write config file.
const char *filename = TEST_ROOT "/test_verify_devices_module_calls_cgroup_parameter.cfg";
write_and_load_devices_module_to_cfg(filename, 1);
char* container_id = (char*) "container_1498064906505_0001_01_000001";
char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
char allowed_devices[] = "243:2";
char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
excluded_devices,
(char*) "--allowed_devices",
allowed_devices,
(char*) "--container_id",
container_id };
/* Test case 1: block 2 devices */
clear_cgroups_parameters_invoked();
int rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 7, argv);
ASSERT_EQ(0, rc) << "Should success.\n";
// Verify cgroups parameters
const char* expected_cgroups_argv[] = { "devices", "deny", container_id, "c 243:0 rwm",
"devices", "deny", container_id, "c 243:1 rwm"};
verify_param_updated_to_cgroups(8, expected_cgroups_argv);
/* Test case 2: block 0 devices */
clear_cgroups_parameters_invoked();
char* argv_1[] = { (char*) "--module-devices", (char*) "--container_id", container_id };
rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 3, argv_1);
ASSERT_EQ(0, rc) << "Should success.\n";
// Verify cgroups parameters
verify_param_updated_to_cgroups(0, NULL);
clear_cgroups_parameters_invoked();
free_executor_configurations();
}
TEST_F(TestDevicesModule, test_update_cgroup_parameter_with_config) {
// Write config file.
const char *filename = TEST_ROOT "/test_update_cgroup_parameter_with_config.cfg";
write_and_load_devices_module_to_cfg(filename, 1);
// Add denied numbers
char tokens[] = "devices.denied-numbers=243:1\n";
append_config(filename, tokens);
char* container_id = (char*) "container_1498064906505_0001_01_000001";
char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
char allowed_devices[] = "243:2";
char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
excluded_devices,
(char*) "--allowed_devices",
allowed_devices,
(char*) "--container_id",
container_id };
/* Test case 1: block 2 devices */
clear_cgroups_parameters_invoked();
int rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 7, argv);
ASSERT_EQ(0, rc) << "Should success.\n";
// Verify cgroups parameters
const char* expected_cgroups_argv[] = { "devices", "deny", container_id, "c 243:0 rwm",
"devices", "deny", container_id, "c 243:1 rwm"};
verify_param_updated_to_cgroups(8, expected_cgroups_argv);
/* Test case 2: block 2 devices but try allow devices not permitted by config*/
clear_cgroups_parameters_invoked();
// device plugin reported 0,1,2,3 totally. Allocated 1,2
// But c-e.cfg has device 1 denied.
char excluded_devices2[] = "c-243:0-rwm,c-243:3-rwm";
char allowed_devices2[] = "243:1,243:2";
char* argv1[] = { (char*) "--module-devices", (char*) "--excluded_devices",
excluded_devices2,
(char*) "--allowed_devices",
allowed_devices2,
(char*) "--container_id",
container_id };
rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 7, argv1);
ASSERT_NE(0, rc) << "Should fail.\n";
clear_cgroups_parameters_invoked();
free_executor_configurations();
}
TEST_F(TestDevicesModule, test_illegal_cli_parameters) {
// Write config file.
const char *filename = TEST_ROOT "/test_illegal_cli_parameters.cfg";
write_and_load_devices_module_to_cfg(filename, 1);
char excluded_devices[] = "c-243:0-rwm,c-243:1-rwm";
char allowed_devices[] = "243:2";
// Illegal container id - 1
char* argv[] = { (char*) "--module-devices", (char*) "--excluded_devices",
excluded_devices,
(char*) "--allowed_devices",
allowed_devices,
(char*) "--container_id", (char*) "xxxx" };
int rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 7, argv);
ASSERT_NE(0, rc) << "Should fail.\n";
// Illegal container id - 2
clear_cgroups_parameters_invoked();
char* argv_1[] = { (char*) "--module-devices", (char*) "--excluded_devices",
excluded_devices,
(char*) "--allowed_devices",
allowed_devices,
(char*) "--container_id", (char*) "container_1" };
rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 7, argv_1);
ASSERT_NE(0, rc) << "Should fail.\n";
// Illegal container id - 3
clear_cgroups_parameters_invoked();
char* argv_2[] = { (char*) "--module-devices",
(char*) "--excluded_devices",
excluded_devices };
rc = handle_devices_request(&mock_update_cgroups_parameters,
"devices", 3, argv_2);
ASSERT_NE(0, rc) << "Should fail.\n";
clear_cgroups_parameters_invoked();
free_executor_configurations();
}
TEST_F(TestDevicesModule, test_devices_module_disabled) {
test_devices_module_enabled_disabled(0);
}
TEST_F(TestDevicesModule, test_devices_module_enabled) {
test_devices_module_enabled_disabled(1);
}
} // namespace ContainerExecutor

View File

@ -25,6 +25,7 @@ import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
@ -33,6 +34,8 @@ import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeS
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
@ -74,6 +77,10 @@ public class TestDeviceMappingManager {
private ExecutorService containerLauncher;
private Configuration conf;
private CGroupsHandler mockCGroupsHandler;
private PrivilegedOperationExecutor mockPrivilegedExecutor;
private Context mockCtx;
@Before
public void setup() throws Exception {
// setup resource-types.xml
@ -89,7 +96,7 @@ public class TestDeviceMappingManager {
isA(String.class),
isA(ArrayList.class));
dmm = new DeviceMappingManager(context);
int deviceCount = 600;
int deviceCount = 100;
TreeSet<Device> r = new TreeSet<>();
for (int i = 0; i < deviceCount; i++) {
r.add(Device.Builder.newInstance()
@ -117,6 +124,10 @@ public class TestDeviceMappingManager {
containerLauncher =
Executors.newFixedThreadPool(10);
mockCGroupsHandler = mock(CGroupsHandler.class);
mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
mockCtx = mock(NodeManager.NMContext.class);
when(mockCtx.getConf()).thenReturn(conf);
}
@After
@ -134,7 +145,7 @@ public class TestDeviceMappingManager {
@Test
public void testAllocation()
throws InterruptedException, ResourceHandlerException {
int totalContainerCount = 100;
int totalContainerCount = 10;
String resourceName1 = "cmpA.com/hdwA";
String resourceName2 = "cmp.com/cmp";
DeviceMappingManager dmmSpy = spy(dmm);
@ -158,11 +169,12 @@ public class TestDeviceMappingManager {
resourceName,
num, false);
containerSet.get(resourceName).put(c, num);
DevicePlugin myPlugin = new MyTestPlugin();
DevicePluginAdapter dpa = new DevicePluginAdapter(resourceName,
myPlugin, dmm);
DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
resourceName,
new MyTestPlugin(), null,
dmmSpy, null, null);
resourceName, dpa,
dmmSpy, mockCGroupsHandler, mockPrivilegedExecutor, mockCtx);
Future<Integer> f = containerLauncher.submit(new MyContainerLaunch(
dri, c, i, false));
}
@ -173,12 +185,11 @@ public class TestDeviceMappingManager {
}
Long endTime = System.currentTimeMillis();
LOG.info("Each container allocation spends roughly: {} ms",
LOG.info("Each container preStart spends roughly: {} ms",
(endTime - startTime)/totalContainerCount);
// Ensure invocation times
verify(dmmSpy, times(totalContainerCount)).assignDevices(
anyString(), any(Container.class));
// Ensure used devices' count for each type is correct
int totalAllocatedCount = 0;
Map<Device, ContainerId> used1 =
@ -198,23 +209,15 @@ public class TestDeviceMappingManager {
for (Map.Entry<Container, Integer> entry :
containerSet.get(resourceName1).entrySet()) {
int containerWanted = entry.getValue();
int actualAllocated = 0;
for (ContainerId cid : used1.values()) {
if (cid.equals(entry.getKey().getContainerId())) {
actualAllocated++;
}
}
int actualAllocated = dmm.getAllocatedDevices(resourceName1,
entry.getKey().getContainerId()).size();
Assert.assertEquals(containerWanted, actualAllocated);
}
for (Map.Entry<Container, Integer> entry :
containerSet.get(resourceName2).entrySet()) {
int containerWanted = entry.getValue();
int actualAllocated = 0;
for (ContainerId cid : used2.values()) {
if (cid.equals(entry.getKey().getContainerId())) {
actualAllocated++;
}
}
int actualAllocated = dmm.getAllocatedDevices(resourceName2,
entry.getKey().getContainerId()).size();
Assert.assertEquals(containerWanted, actualAllocated);
}
}
@ -248,11 +251,12 @@ public class TestDeviceMappingManager {
resourceName,
num, false);
containerSet.get(resourceName).put(c, num);
DevicePlugin myPlugin = new MyTestPlugin();
DevicePluginAdapter dpa = new DevicePluginAdapter(resourceName,
myPlugin, dmm);
DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
resourceName,
new MyTestPlugin(), null,
dmmSpy, null, null);
resourceName, dpa,
dmmSpy, mockCGroupsHandler, mockPrivilegedExecutor, mockCtx);
Future<Integer> f = containerLauncher.submit(new MyContainerLaunch(
dri, c, i, true));
}
@ -262,7 +266,6 @@ public class TestDeviceMappingManager {
LOG.info("Wait for the threads to finish");
}
// Ensure invocation times
verify(dmmSpy, times(totalContainerCount)).assignDevices(
anyString(), any(Container.class));

View File

@ -18,7 +18,6 @@
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import org.apache.hadoop.service.ServiceOperations;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
@ -34,12 +33,20 @@ import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePluginManager;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMMemoryStateStoreService;
@ -51,6 +58,7 @@ import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -60,15 +68,21 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
@ -89,7 +103,6 @@ public class TestDevicePluginAdapter {
private String tempResourceTypesFile;
private CGroupsHandler mockCGroupsHandler;
private PrivilegedOperationExecutor mockPrivilegedExecutor;
private NodeManager nm;
@Before
public void setup() throws Exception {
@ -110,13 +123,6 @@ public class TestDevicePluginAdapter {
if (dest.exists()) {
dest.delete();
}
if (nm != null) {
try {
ServiceOperations.stop(nm);
} catch (Throwable t) {
// ignore
}
}
}
@ -130,16 +136,14 @@ public class TestDevicePluginAdapter {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
// Init scheduler manager
DeviceMappingManager dmm = new DeviceMappingManager(context);
ResourcePluginManager rpm = mock(ResourcePluginManager.class);
when(rpm.getDeviceMappingManager()).thenReturn(dmm);
// Init an plugin
MyPlugin plugin = new MyPlugin();
MyPlugin spyPlugin = spy(plugin);
@ -150,14 +154,19 @@ public class TestDevicePluginAdapter {
spyPlugin, dmm);
// Bootstrap, adding device
adapter.initialize(context);
adapter.createResourceHandler(context,
mockCGroupsHandler, mockPrivilegedExecutor);
// Use mock shell when create resourceHandler
ShellWrapper mockShellWrapper = mock(ShellWrapper.class);
when(mockShellWrapper.existFile(anyString())).thenReturn(true);
when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
DeviceResourceHandlerImpl drhl = new DeviceResourceHandlerImpl(resourceName,
adapter, dmm, mockCGroupsHandler, mockPrivilegedExecutor, context,
mockShellWrapper);
adapter.setDeviceResourceHandler(drhl);
adapter.getDeviceResourceHandler().bootstrap(conf);
int size = dmm.getAvailableDevices(resourceName);
Assert.assertEquals(3, size);
// A container c1 requests 1 device
Container c1 = mockContainerWithDeviceRequest(0,
// Case 1. A container c1 requests 1 device
Container c1 = mockContainerWithDeviceRequest(1,
resourceName,
1, false);
// preStart
@ -169,19 +178,33 @@ public class TestDevicePluginAdapter {
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
Assert.assertEquals(1,
dmm.getAllocatedDevices(resourceName, c1.getContainerId()).size());
verify(mockShellWrapper, times(2)).getDeviceFileType(anyString());
// check device cgroup create operation
checkCgroupOperation(c1.getContainerId().toString(), 1,
"c-256:1-rwm,c-256:2-rwm", "256:0");
// postComplete
adapter.getDeviceResourceHandler().postComplete(getContainerId(0));
adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
Assert.assertEquals(3,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(0,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
// A container c2 requests 3 device
Container c2 = mockContainerWithDeviceRequest(1,
// check cgroup delete operation
verify(mockCGroupsHandler).deleteCGroup(
CGroupsHandler.CGroupController.DEVICES,
c1.getContainerId().toString());
// Case 2. A container c2 requests 3 device
Container c2 = mockContainerWithDeviceRequest(2,
resourceName,
3, false);
reset(mockShellWrapper);
reset(mockCGroupsHandler);
reset(mockPrivilegedExecutor);
when(mockShellWrapper.existFile(anyString())).thenReturn(true);
when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
// preStart
adapter.getDeviceResourceHandler().preStart(c2);
// check book keeping
@ -191,19 +214,37 @@ public class TestDevicePluginAdapter {
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllocatedDevices(resourceName, c2.getContainerId()).size());
verify(mockShellWrapper, times(0)).getDeviceFileType(anyString());
// check device cgroup create operation
verify(mockCGroupsHandler).createCGroup(
CGroupsHandler.CGroupController.DEVICES,
c2.getContainerId().toString());
// check device cgroup update operation
checkCgroupOperation(c2.getContainerId().toString(), 1,
null, "256:0,256:1,256:2");
// postComplete
adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
adapter.getDeviceResourceHandler().postComplete(getContainerId(2));
Assert.assertEquals(3,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(0,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
// A container c3 request 0 device
Container c3 = mockContainerWithDeviceRequest(1,
// check cgroup delete operation
verify(mockCGroupsHandler).deleteCGroup(
CGroupsHandler.CGroupController.DEVICES,
c2.getContainerId().toString());
// Case 3. A container c3 request 0 device
Container c3 = mockContainerWithDeviceRequest(3,
resourceName,
0, false);
reset(mockShellWrapper);
reset(mockCGroupsHandler);
reset(mockPrivilegedExecutor);
when(mockShellWrapper.existFile(anyString())).thenReturn(true);
when(mockShellWrapper.getDeviceFileType(anyString())).thenReturn("c");
// preStart
adapter.getDeviceResourceHandler().preStart(c3);
// check book keeping
@ -213,14 +254,57 @@ public class TestDevicePluginAdapter {
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
verify(mockShellWrapper, times(3)).getDeviceFileType(anyString());
// check device cgroup create operation
verify(mockCGroupsHandler).createCGroup(
CGroupsHandler.CGroupController.DEVICES,
c3.getContainerId().toString());
// check device cgroup update operation
checkCgroupOperation(c3.getContainerId().toString(), 1,
"c-256:0-rwm,c-256:1-rwm,c-256:2-rwm", null);
// postComplete
adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
adapter.getDeviceResourceHandler().postComplete(getContainerId(3));
Assert.assertEquals(3,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(0,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
Assert.assertEquals(0,
dmm.getAllocatedDevices(resourceName, c3.getContainerId()).size());
// check cgroup delete operation
verify(mockCGroupsHandler).deleteCGroup(
CGroupsHandler.CGroupController.DEVICES,
c3.getContainerId().toString());
}
private void checkCgroupOperation(String cId,
int invokeTimesOfPrivilegedExecutor,
String excludedParam, String allowedParam)
throws PrivilegedOperationException, ResourceHandlerException {
verify(mockCGroupsHandler).createCGroup(
CGroupsHandler.CGroupController.DEVICES,
cId);
// check device cgroup update operation
ArgumentCaptor<PrivilegedOperation> args =
ArgumentCaptor.forClass(PrivilegedOperation.class);
verify(mockPrivilegedExecutor, times(invokeTimesOfPrivilegedExecutor))
.executePrivilegedOperation(args.capture(), eq(true));
Assert.assertEquals(PrivilegedOperation.OperationType.DEVICE,
args.getValue().getOperationType());
List<String> expectedArgs = new ArrayList<>();
expectedArgs.add(DeviceResourceHandlerImpl.CONTAINER_ID_CLI_OPTION);
expectedArgs.add(cId);
if (excludedParam != null && !excludedParam.isEmpty()) {
expectedArgs.add(DeviceResourceHandlerImpl.EXCLUDED_DEVICES_CLI_OPTION);
expectedArgs.add(excludedParam);
}
if (allowedParam != null && !allowedParam.isEmpty()) {
expectedArgs.add(DeviceResourceHandlerImpl.ALLOWED_DEVICES_CLI_OPTION);
expectedArgs.add(allowedParam);
}
Assert.assertArrayEquals(expectedArgs.toArray(),
args.getValue().getArguments().toArray());
}
@Test
@ -251,6 +335,7 @@ public class TestDevicePluginAdapter {
NMStateStoreService realStoreService = new NMMemoryStateStoreService();
NMStateStoreService storeService = spy(realStoreService);
when(context.getNMStateStore()).thenReturn(storeService);
when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
@ -395,6 +480,7 @@ public class TestDevicePluginAdapter {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService realStoreService = new NMMemoryStateStoreService();
NMStateStoreService storeService = spy(realStoreService);
when(context.getConf()).thenReturn(this.conf);
when(context.getNMStateStore()).thenReturn(storeService);
doThrow(new IOException("Exception ...")).when(storeService)
.storeAssignedResources(isA(Container.class),
@ -448,6 +534,7 @@ public class TestDevicePluginAdapter {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
@ -526,6 +613,7 @@ public class TestDevicePluginAdapter {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
@ -584,6 +672,206 @@ public class TestDevicePluginAdapter {
Assert.assertEquals(3, response.getTotalDevices().size());
}
/**
* Test a container run command update when using Docker runtime.
* And the device plugin it uses is like Nvidia Docker v1.
* */
@Test
public void testDeviceResourceDockerRuntimePlugin1() throws Exception {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
// Init scheduler manager
DeviceMappingManager dmm = new DeviceMappingManager(context);
DeviceMappingManager spyDmm = spy(dmm);
ResourcePluginManager rpm = mock(ResourcePluginManager.class);
when(rpm.getDeviceMappingManager()).thenReturn(spyDmm);
// Init a plugin
MyPlugin plugin = new MyPlugin();
MyPlugin spyPlugin = spy(plugin);
String resourceName = MyPlugin.RESOURCE_NAME;
// Init an adapter for the plugin
DevicePluginAdapter adapter = new DevicePluginAdapter(
resourceName,
spyPlugin, spyDmm);
adapter.initialize(context);
// Bootstrap, adding device
adapter.initialize(context);
adapter.createResourceHandler(context,
mockCGroupsHandler, mockPrivilegedExecutor);
adapter.getDeviceResourceHandler().bootstrap(conf);
// Case 1. A container request Docker runtime and 1 device
Container c1 = mockContainerWithDeviceRequest(1, resourceName, 1, true);
// generate spec based on v1
spyPlugin.setDevicePluginVersion("v1");
// preStart will do allocation
adapter.getDeviceResourceHandler().preStart(c1);
Set<Device> allocatedDevice = spyDmm.getAllocatedDevices(resourceName,
c1.getContainerId());
reset(spyDmm);
// c1 is requesting docker runtime.
// it will create parent cgroup but no cgroups update operation needed.
// check device cgroup create operation
verify(mockCGroupsHandler).createCGroup(
CGroupsHandler.CGroupController.DEVICES,
c1.getContainerId().toString());
// ensure no cgroups update operation
verify(mockPrivilegedExecutor, times(0))
.executePrivilegedOperation(
any(PrivilegedOperation.class), anyBoolean());
DockerCommandPlugin dcp = adapter.getDockerCommandPluginInstance();
// When DockerLinuxContainerRuntime invoke the DockerCommandPluginInstance
// First to create volume
DockerVolumeCommand dvc = dcp.getCreateDockerVolumeCommand(c1);
// ensure that allocation is get once from device mapping manager
verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
// ensure that plugin's onDeviceAllocated is invoked
verify(spyPlugin).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DEFAULT);
verify(spyPlugin).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DOCKER);
Assert.assertEquals("nvidia-docker", dvc.getDriverName());
Assert.assertEquals("create", dvc.getSubCommand());
Assert.assertEquals("nvidia_driver_352.68", dvc.getVolumeName());
// then the DockerLinuxContainerRuntime will update docker run command
DockerRunCommand drc =
new DockerRunCommand(c1.getContainerId().toString(), "user",
"image/tensorflow");
// reset to avoid count times in above invocation
reset(spyPlugin);
reset(spyDmm);
// Second, update the run command.
dcp.updateDockerRunCommand(drc, c1);
// The spec is already generated in getCreateDockerVolumeCommand
// and there should be a cache hit for DeviceRuntime spec.
verify(spyPlugin, times(0)).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DOCKER);
// ensure that allocation is get from cache instead of device mapping
// manager
verify(spyDmm, times(0)).getAllocatedDevices(resourceName,
c1.getContainerId());
String runStr = drc.toString();
Assert.assertTrue(
runStr.contains("nvidia_driver_352.68:/usr/local/nvidia:ro"));
Assert.assertTrue(runStr.contains("/dev/hdwA0:/dev/hdwA0"));
// Third, cleanup in getCleanupDockerVolumesCommand
dcp.getCleanupDockerVolumesCommand(c1);
// Ensure device plugin's onDeviceReleased is invoked
verify(spyPlugin).onDevicesReleased(allocatedDevice);
// If we run the c1 again. No cache will be used for allocation and spec
dcp.getCreateDockerVolumeCommand(c1);
verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
verify(spyPlugin).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DOCKER);
}
/**
* Test a container run command update when using Docker runtime.
* And the device plugin it uses is like Nvidia Docker v2.
* */
@Test
public void testDeviceResourceDockerRuntimePlugin2() throws Exception {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
when(context.getConf()).thenReturn(this.conf);
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
// Init scheduler manager
DeviceMappingManager dmm = new DeviceMappingManager(context);
DeviceMappingManager spyDmm = spy(dmm);
ResourcePluginManager rpm = mock(ResourcePluginManager.class);
when(rpm.getDeviceMappingManager()).thenReturn(spyDmm);
// Init a plugin
MyPlugin plugin = new MyPlugin();
MyPlugin spyPlugin = spy(plugin);
String resourceName = MyPlugin.RESOURCE_NAME;
// Init an adapter for the plugin
DevicePluginAdapter adapter = new DevicePluginAdapter(
resourceName,
spyPlugin, spyDmm);
adapter.initialize(context);
// Bootstrap, adding device
adapter.initialize(context);
adapter.createResourceHandler(context,
mockCGroupsHandler, mockPrivilegedExecutor);
adapter.getDeviceResourceHandler().bootstrap(conf);
// Case 1. A container request Docker runtime and 1 device
Container c1 = mockContainerWithDeviceRequest(1, resourceName, 2, true);
// generate spec based on v2
spyPlugin.setDevicePluginVersion("v2");
// preStart will do allocation
adapter.getDeviceResourceHandler().preStart(c1);
Set<Device> allocatedDevice = spyDmm.getAllocatedDevices(resourceName,
c1.getContainerId());
reset(spyDmm);
// c1 is requesting docker runtime.
// it will create parent cgroup but no cgroups update operation needed.
// check device cgroup create operation
verify(mockCGroupsHandler).createCGroup(
CGroupsHandler.CGroupController.DEVICES,
c1.getContainerId().toString());
// ensure no cgroups update operation
verify(mockPrivilegedExecutor, times(0))
.executePrivilegedOperation(
any(PrivilegedOperation.class), anyBoolean());
DockerCommandPlugin dcp = adapter.getDockerCommandPluginInstance();
// When DockerLinuxContainerRuntime invoke the DockerCommandPluginInstance
// First to create volume
DockerVolumeCommand dvc = dcp.getCreateDockerVolumeCommand(c1);
// ensure that allocation is get once from device mapping manager
verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
// ensure that plugin's onDeviceAllocated is invoked
verify(spyPlugin).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DEFAULT);
verify(spyPlugin).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DOCKER);
// No volume creation request
Assert.assertNull(dvc);
// then the DockerLinuxContainerRuntime will update docker run command
DockerRunCommand drc =
new DockerRunCommand(c1.getContainerId().toString(), "user",
"image/tensorflow");
// reset to avoid count times in above invocation
reset(spyPlugin);
reset(spyDmm);
// Second, update the run command.
dcp.updateDockerRunCommand(drc, c1);
// The spec is already generated in getCreateDockerVolumeCommand
// and there should be a cache hit for DeviceRuntime spec.
verify(spyPlugin, times(0)).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DOCKER);
// ensure that allocation is get once from device mapping manager
verify(spyDmm, times(0)).getAllocatedDevices(resourceName,
c1.getContainerId());
Assert.assertEquals("0,1", drc.getEnv().get("NVIDIA_VISIBLE_DEVICES"));
Assert.assertTrue(drc.toString().contains("runtime=nvidia"));
// Third, cleanup in getCleanupDockerVolumesCommand
dcp.getCleanupDockerVolumesCommand(c1);
// Ensure device plugin's onDeviceReleased is invoked
verify(spyPlugin).onDevicesReleased(allocatedDevice);
// If we run the c1 again. No cache will be used for allocation and spec
dcp.getCreateDockerVolumeCommand(c1);
verify(spyDmm).getAllocatedDevices(resourceName, c1.getContainerId());
verify(spyPlugin).onDevicesAllocated(
allocatedDevice,
YarnRuntimeType.RUNTIME_DOCKER);
}
private static ContainerId getContainerId(int id) {
return ContainerId.newContainerId(ApplicationAttemptId
.newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
@ -591,6 +879,15 @@ public class TestDevicePluginAdapter {
private class MyPlugin implements DevicePlugin, DevicePluginScheduler {
private final static String RESOURCE_NAME = "cmpA.com/hdwA";
// v1 means the vendor uses the similar way of Nvidia Docker v1
// v2 means the vendor user the similar way of Nvidia Docker v2
private String devicePluginVersion = "v2";
public void setDevicePluginVersion(String version) {
devicePluginVersion = version;
}
@Override
public DeviceRegisterRequest getRegisterRequestInfo() {
return DeviceRegisterRequest.Builder.newInstance()
@ -613,7 +910,7 @@ public class TestDevicePluginAdapter {
.setId(1)
.setDevPath("/dev/hdwA1")
.setMajorNumber(256)
.setMinorNumber(0)
.setMinorNumber(1)
.setBusID("0000:80:01.0")
.setHealthy(true)
.build());
@ -621,7 +918,7 @@ public class TestDevicePluginAdapter {
.setId(2)
.setDevPath("/dev/hdwA2")
.setMajorNumber(256)
.setMinorNumber(0)
.setMinorNumber(2)
.setBusID("0000:80:02.0")
.setHealthy(true)
.build());
@ -631,12 +928,69 @@ public class TestDevicePluginAdapter {
@Override
public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
YarnRuntimeType yarnRuntime) throws Exception {
if (yarnRuntime == YarnRuntimeType.RUNTIME_DEFAULT) {
return null;
}
if (yarnRuntime == YarnRuntimeType.RUNTIME_DOCKER) {
return generateSpec(devicePluginVersion, allocatedDevices);
}
return null;
}
private DeviceRuntimeSpec generateSpec(String version,
Set<Device> allocatedDevices) {
DeviceRuntimeSpec.Builder builder =
DeviceRuntimeSpec.Builder.newInstance();
if (version.equals("v1")) {
// Nvidia v1 examples like below. These info is get from Nvidia v1
// RESTful.
// --device=/dev/nvidiactl --device=/dev/nvidia-uvm
// --device=/dev/nvidia0
// --volume-driver=nvidia-docker
// --volume=nvidia_driver_352.68:/usr/local/nvidia:ro
String volumeDriverName = "nvidia-docker";
String volumeToBeCreated = "nvidia_driver_352.68";
String volumePathInContainer = "/usr/local/nvidia";
// describe volumes to be created and mounted
builder.addVolumeSpec(
VolumeSpec.Builder.newInstance()
.setVolumeDriver(volumeDriverName)
.setVolumeName(volumeToBeCreated)
.setVolumeOperation(VolumeSpec.CREATE).build())
.addMountVolumeSpec(
MountVolumeSpec.Builder.newInstance()
.setHostPath(volumeToBeCreated)
.setMountPath(volumePathInContainer)
.setReadOnly(true).build());
// describe devices to be mounted
for (Device device : allocatedDevices) {
builder.addMountDeviceSpec(
MountDeviceSpec.Builder.newInstance()
.setDevicePathInHost(device.getDevPath())
.setDevicePathInContainer(device.getDevPath())
.setDevicePermission(MountDeviceSpec.RW).build());
}
}
if (version.equals("v2")) {
String nvidiaRuntime = "nvidia";
String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES";
StringBuffer gpuMinorNumbersSB = new StringBuffer();
for (Device device : allocatedDevices) {
gpuMinorNumbersSB.append(device.getMinorNumber() + ",");
}
String minorNumbers = gpuMinorNumbersSB.toString();
// set runtime and environment variable is enough for
// plugin like Nvidia Docker v2
builder.addEnv(nvidiaVisibleDevices,
minorNumbers.substring(0, minorNumbers.length() - 1))
.setContainerRuntime(nvidiaRuntime);
}
return builder.build();
}
@Override
public void onDevicesReleased(Set<Device> releasedDevices) {
// nothing to do
}
@Override

View File

@ -0,0 +1,108 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.nvidia.com;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia.NvidiaGPUPluginForRuntimeV2;
import org.junit.Assert;
import org.junit.Test;
import java.util.Set;
import java.util.TreeSet;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Test case for Nvidia GPU device plugin.
* */
public class TestNvidiaGpuPlugin {
@Test
public void testGetNvidiaDevices() throws Exception {
NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor mockShell =
mock(NvidiaGPUPluginForRuntimeV2.NvidiaCommandExecutor.class);
String deviceInfoShellOutput =
"0, 00000000:04:00.0\n" +
"1, 00000000:82:00.0";
String majorMinorNumber0 = "c3:0";
String majorMinorNumber1 = "c3:1";
when(mockShell.getDeviceInfo()).thenReturn(deviceInfoShellOutput);
when(mockShell.getMajorMinorInfo("nvidia0"))
.thenReturn(majorMinorNumber0);
when(mockShell.getMajorMinorInfo("nvidia1"))
.thenReturn(majorMinorNumber1);
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
plugin.setShellExecutor(mockShell);
plugin.setPathOfGpuBinary("/fake/nvidia-smi");
Set<Device> expectedDevices = new TreeSet<>();
expectedDevices.add(Device.Builder.newInstance()
.setId(0).setHealthy(true)
.setBusID("00000000:04:00.0")
.setDevPath("/dev/nvidia0")
.setMajorNumber(195)
.setMinorNumber(0).build());
expectedDevices.add(Device.Builder.newInstance()
.setId(1).setHealthy(true)
.setBusID("00000000:82:00.0")
.setDevPath("/dev/nvidia1")
.setMajorNumber(195)
.setMinorNumber(1).build());
Set<Device> devices = plugin.getDevices();
Assert.assertEquals(expectedDevices, devices);
}
@Test
public void testOnDeviceAllocated() throws Exception {
NvidiaGPUPluginForRuntimeV2 plugin = new NvidiaGPUPluginForRuntimeV2();
Set<Device> allocatedDevices = new TreeSet<>();
DeviceRuntimeSpec spec = plugin.onDevicesAllocated(allocatedDevices,
YarnRuntimeType.RUNTIME_DEFAULT);
Assert.assertNull(spec);
// allocate one device
allocatedDevices.add(Device.Builder.newInstance()
.setId(0).setHealthy(true)
.setBusID("00000000:04:00.0")
.setDevPath("/dev/nvidia0")
.setMajorNumber(195)
.setMinorNumber(0).build());
spec = plugin.onDevicesAllocated(allocatedDevices,
YarnRuntimeType.RUNTIME_DOCKER);
Assert.assertEquals("nvidia", spec.getContainerRuntime());
Assert.assertEquals("0", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
// two device allowed
allocatedDevices.add(Device.Builder.newInstance()
.setId(0).setHealthy(true)
.setBusID("00000000:82:00.0")
.setDevPath("/dev/nvidia1")
.setMajorNumber(195)
.setMinorNumber(1).build());
spec = plugin.onDevicesAllocated(allocatedDevices,
YarnRuntimeType.RUNTIME_DOCKER);
Assert.assertEquals("nvidia", spec.getContainerRuntime());
Assert.assertEquals("0,1", spec.getEnvs().get("NVIDIA_VISIBLE_DEVICES"));
}
}