diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/fpga/TestFpgaResourceHandlerImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/fpga/TestFpgaResourceHandlerImpl.java
new file mode 100644
index 00000000000..77f992cd32d
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/fpga/TestFpgaResourceHandlerImpl.java
@@ -0,0 +1,611 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.fpga;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyList;
+import static org.mockito.ArgumentMatchers.anyMap;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.fpga.FpgaResourceAllocator.FpgaDevice;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.localizer.ResourceSet;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.fpga.FpgaDiscoverer;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.fpga.IntelFpgaOpenclPlugin;
+import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
+import org.apache.hadoop.yarn.util.resource.CustomResourceTypesConfigurationProvider;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import com.google.common.io.Files;
+import com.google.common.io.FileWriteMode;
+
+public class TestFpgaResourceHandlerImpl {
+ @Rule
+ public ExpectedException expected = ExpectedException.none();
+
+ private static final String HASHABLE_STRING = "abcdef";
+ private static final String EXPECTED_HASH =
+ "bef57ec7f53a6d40beb640a780a639c83bc29ac8a9816f1fc6c5c6dcd93c4721";
+
+ private Context mockContext;
+ private FpgaResourceHandlerImpl fpgaResourceHandler;
+ private Configuration configuration;
+ private CGroupsHandler mockCGroupsHandler;
+ private PrivilegedOperationExecutor mockPrivilegedExecutor;
+ private NMStateStoreService mockNMStateStore;
+ private ConcurrentHashMap runningContainersMap;
+ private IntelFpgaOpenclPlugin mockVendorPlugin;
+ private List deviceList;
+ private FpgaDiscoverer fpgaDiscoverer;
+ private static final String vendorType = "IntelOpenCL";
+ private File dummyAocx;
+
+ private String getTestParentFolder() {
+ File f = new File("target/temp/" +
+ TestFpgaResourceHandlerImpl.class.getName());
+ return f.getAbsolutePath();
+ }
+
+ @Before
+ public void setup() throws IOException, YarnException {
+ CustomResourceTypesConfigurationProvider.
+ initResourceTypes(ResourceInformation.FPGA_URI);
+ configuration = new YarnConfiguration();
+
+ mockCGroupsHandler = mock(CGroupsHandler.class);
+ mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
+ mockNMStateStore = mock(NMStateStoreService.class);
+ mockContext = mock(Context.class);
+ // Assumed devices parsed from output
+ deviceList = new ArrayList<>();
+ for (int i = 0; i < 5; i++) {
+ deviceList.add(new FpgaDevice(vendorType, 247, i, "acl" + i));
+ }
+ String aocxPath = getTestParentFolder() + "/test.aocx";
+ mockVendorPlugin = mockPlugin(vendorType, deviceList, aocxPath);
+ fpgaDiscoverer = new FpgaDiscoverer();
+ fpgaDiscoverer.setResourceHanderPlugin(mockVendorPlugin);
+ fpgaDiscoverer.initialize(configuration);
+ when(mockContext.getNMStateStore()).thenReturn(mockNMStateStore);
+ runningContainersMap = new ConcurrentHashMap<>();
+ when(mockContext.getContainers()).thenReturn(runningContainersMap);
+
+ fpgaResourceHandler = new FpgaResourceHandlerImpl(mockContext,
+ mockCGroupsHandler, mockPrivilegedExecutor, mockVendorPlugin,
+ fpgaDiscoverer);
+
+ dummyAocx = new File(aocxPath);
+ Files.createParentDirs(dummyAocx);
+ Files.touch(dummyAocx);
+ Files.asCharSink(dummyAocx, StandardCharsets.UTF_8, FileWriteMode.APPEND)
+ .write(HASHABLE_STRING);
+ }
+
+ @After
+ public void teardown() {
+ if (dummyAocx != null) {
+ dummyAocx.delete();
+ }
+ }
+
+ @Test
+ public void testBootstrap() throws ResourceHandlerException {
+ // Case 1. auto
+ String allowed = "auto";
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, allowed);
+ fpgaResourceHandler.bootstrap(configuration);
+ // initPlugin() was also called in setup()
+ verify(mockVendorPlugin, times(2)).initPlugin(configuration);
+ verify(mockCGroupsHandler, times(1)).initializeCGroupController(
+ CGroupsHandler.CGroupController.DEVICES);
+ Assert.assertEquals(5, fpgaResourceHandler.getFpgaAllocator()
+ .getAvailableFpgaCount());
+ Assert.assertEquals(5, fpgaResourceHandler.getFpgaAllocator()
+ .getAllowedFpga().size());
+ // Case 2. subset of devices
+ fpgaResourceHandler = new FpgaResourceHandlerImpl(mockContext,
+ mockCGroupsHandler, mockPrivilegedExecutor, mockVendorPlugin,
+ fpgaDiscoverer);
+ allowed = "0,1,2";
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, allowed);
+ fpgaResourceHandler.bootstrap(configuration);
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAllowedFpga().size());
+ List allowedDevices =
+ fpgaResourceHandler.getFpgaAllocator().getAllowedFpga();
+ for (String s : allowed.split(",")) {
+ boolean check = false;
+ for (FpgaDevice device : allowedDevices) {
+ if (String.valueOf(device.getMinor()).equals(s)) {
+ check = true;
+ }
+ }
+ Assert.assertTrue("Minor:" + s +" found", check);
+ }
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+
+ // Case 3. User configuration contains invalid minor device number
+ fpgaResourceHandler = new FpgaResourceHandlerImpl(mockContext,
+ mockCGroupsHandler, mockPrivilegedExecutor, mockVendorPlugin,
+ fpgaDiscoverer);
+ allowed = "0,1,7";
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, allowed);
+ fpgaResourceHandler.bootstrap(configuration);
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getAllowedFpga().size());
+ }
+
+ @Test
+ public void testBootstrapWithInvalidUserConfiguration()
+ throws ResourceHandlerException {
+ // User configuration contains invalid minor device number
+ String allowed = "0,1,7";
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, allowed);
+ fpgaResourceHandler.bootstrap(configuration);
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getAllowedFpga().size());
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+
+ String[] invalidAllowedStrings = {"a,1,2,", "a,1,2", "0,1,2,#", "a", "1,"};
+ for (String s : invalidAllowedStrings) {
+ boolean invalidConfiguration = false;
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, s);
+ try {
+ fpgaResourceHandler.bootstrap(configuration);
+ } catch (ResourceHandlerException e) {
+ invalidConfiguration = true;
+ }
+ Assert.assertTrue(invalidConfiguration);
+ }
+
+ String[] allowedStrings = {"1,2", "1"};
+ for (String s : allowedStrings) {
+ boolean invalidConfiguration = false;
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, s);
+ try {
+ fpgaResourceHandler.bootstrap(configuration);
+ } catch (ResourceHandlerException e) {
+ invalidConfiguration = true;
+ }
+ Assert.assertFalse(invalidConfiguration);
+ }
+ }
+
+ @Test
+ public void testBootStrapWithEmptyUserConfiguration()
+ throws ResourceHandlerException {
+ // User configuration contains invalid minor device number
+ String allowed = "";
+ boolean invalidConfiguration = false;
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, allowed);
+ try {
+ fpgaResourceHandler.bootstrap(configuration);
+ } catch (ResourceHandlerException e) {
+ invalidConfiguration = true;
+ }
+ Assert.assertTrue(invalidConfiguration);
+ }
+
+ @Test
+ public void testAllocationWithPreference()
+ throws ResourceHandlerException, PrivilegedOperationException {
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, "0,1,2");
+ fpgaResourceHandler.bootstrap(configuration);
+ // Case 1. The id-0 container request 1 FPGA of IntelOpenCL type and GEMM IP
+ fpgaResourceHandler.preStart(mockContainer(0, 1, "GEMM"));
+ Assert.assertEquals(1, fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ verifyDeniedDevices(getContainerId(0), Arrays.asList(1, 2));
+ List list = fpgaResourceHandler.getFpgaAllocator()
+ .getUsedFpga().get(getContainerId(0).toString());
+ for (FpgaDevice device : list) {
+ Assert.assertEquals("IP should be updated to GEMM", "GEMM", device.getIPID());
+ }
+ // Case 2. The id-1 container request 3 FPGA of IntelOpenCL and GEMM IP. this should fail
+ boolean flag = false;
+ try {
+ fpgaResourceHandler.preStart(mockContainer(1, 3, "GZIP"));
+ } catch (ResourceHandlerException e) {
+ flag = true;
+ }
+ Assert.assertTrue(flag);
+ // Case 3. Release the id-0 container
+ fpgaResourceHandler.postComplete(getContainerId(0));
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ // Now we have enough devices, re-allocate for the id-1 container
+ fpgaResourceHandler.preStart(mockContainer(1, 3, "GEMM"));
+ // Id-1 container should have 0 denied devices
+ verifyDeniedDevices(getContainerId(1), new ArrayList<>());
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ // Release container id-1
+ fpgaResourceHandler.postComplete(getContainerId(1));
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ // Case 4. Now all 3 devices should have IPID GEMM
+ // Try container id-2 and id-3
+ fpgaResourceHandler.preStart(mockContainer(2, 1, "GZIP"));
+ fpgaResourceHandler.postComplete(getContainerId(2));
+ fpgaResourceHandler.preStart(mockContainer(3, 2, "GEMM"));
+
+ // IPID should be GEMM for id-3 container
+ list = fpgaResourceHandler.getFpgaAllocator()
+ .getUsedFpga().get(getContainerId(3).toString());
+ for (FpgaDevice device : list) {
+ Assert.assertEquals("IPID should be GEMM", "GEMM", device.getIPID());
+ }
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(1,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ fpgaResourceHandler.postComplete(getContainerId(3));
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+
+ // Case 5. id-4 request 0 FPGA device
+ fpgaResourceHandler.preStart(mockContainer(4, 0, ""));
+ // Deny all devices for id-4
+ verifyDeniedDevices(getContainerId(4), Arrays.asList(0, 1, 2));
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+
+ // Case 6. id-5 with invalid FPGA device
+ try {
+ fpgaResourceHandler.preStart(mockContainer(5, -2, ""));
+ } catch (ResourceHandlerException e) {
+ Assert.assertTrue(true);
+ }
+ }
+
+ @Test
+ public void testsAllocationWithExistingIPIDDevices()
+ throws ResourceHandlerException, PrivilegedOperationException,
+ IOException {
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, "0,1,2");
+ fpgaResourceHandler.bootstrap(configuration);
+ // The id-0 container request 3 FPGA of IntelOpenCL type and GEMM IP
+ fpgaResourceHandler.preStart(mockContainer(0, 3, "GEMM"));
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ List list =
+ fpgaResourceHandler
+ .getFpgaAllocator()
+ .getUsedFpga()
+ .get(getContainerId(0).toString());
+ fpgaResourceHandler.postComplete(getContainerId(0));
+ for (FpgaDevice device : list) {
+ Assert.assertEquals("IP should be updated to GEMM", "GEMM",
+ device.getIPID());
+ }
+
+ // Case 1. id-1 container request preStart, with no plugin.configureIP called
+ fpgaResourceHandler.preStart(mockContainer(1, 1, "GEMM"));
+ fpgaResourceHandler.preStart(mockContainer(2, 1, "GEMM"));
+ // we should have 3 times due to id-1 skip 1 invocation
+ verify(mockVendorPlugin, times(3)).configureIP(anyString(),
+ any(FpgaDevice.class));
+ fpgaResourceHandler.postComplete(getContainerId(1));
+ fpgaResourceHandler.postComplete(getContainerId(2));
+
+ // Case 2. id-2 container request preStart, with 1 plugin.configureIP called
+ // Add some characters to the dummy file to have its hash changed
+ Files.asCharSink(dummyAocx, StandardCharsets.UTF_8, FileWriteMode.APPEND)
+ .write("12345");
+ fpgaResourceHandler.preStart(mockContainer(1, 1, "GZIP"));
+ // we should have 4 times invocation
+ verify(mockVendorPlugin, times(4)).configureIP(anyString(),
+ any(FpgaDevice.class));
+ }
+
+ @Test
+ public void testAllocationWithZeroDevices()
+ throws ResourceHandlerException, PrivilegedOperationException {
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, "0,1,2");
+ fpgaResourceHandler.bootstrap(configuration);
+ // The id-0 container request 0 FPGA
+ fpgaResourceHandler.preStart(mockContainer(0, 0, null));
+ verifyDeniedDevices(getContainerId(0), Arrays.asList(0, 1, 2));
+ verify(mockVendorPlugin, times(0)).retrieveIPfilePath(anyString(),
+ anyString(), anyMap());
+ verify(mockVendorPlugin, times(0)).configureIP(anyString(),
+ any(FpgaDevice.class));
+ }
+
+ @Test
+ public void testStateStore()
+ throws ResourceHandlerException, IOException {
+ // Case 1. store 3 devices
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, "0,1,2");
+ fpgaResourceHandler.bootstrap(configuration);
+ Container container0 = mockContainer(0, 3, "GEMM");
+ fpgaResourceHandler.preStart(container0);
+ List assigned =
+ fpgaResourceHandler
+ .getFpgaAllocator()
+ .getUsedFpga()
+ .get(getContainerId(0).toString());
+ verify(mockNMStateStore).storeAssignedResources(container0,
+ ResourceInformation.FPGA_URI,
+ new ArrayList<>(assigned));
+ fpgaResourceHandler.postComplete(getContainerId(0));
+ // Case 2. ask 0, no store api called
+ Container container1 = mockContainer(1, 0, "");
+ fpgaResourceHandler.preStart(container1);
+ verify(mockNMStateStore, never()).storeAssignedResources(
+ eq(container1), eq(ResourceInformation.FPGA_URI), anyList());
+ }
+
+ @Test
+ public void testReacquireContainer() throws ResourceHandlerException {
+ Container c0 = mockContainer(0, 2, "GEMM");
+ List assigned = new ArrayList<>();
+ assigned.add(new FpgaDevice(
+ vendorType, 247, 0, "acl0"));
+ assigned.add(new FpgaDevice(
+ vendorType, 247, 1, "acl1"));
+ // Mock we've stored the c0 states
+ mockStateStoreForContainer(c0, assigned);
+ // NM start
+ configuration.set(YarnConfiguration.NM_FPGA_ALLOWED_DEVICES, "0,1,2");
+ fpgaResourceHandler.bootstrap(configuration);
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ // Case 1. try recover state for id-0 container
+ fpgaResourceHandler.reacquireContainer(getContainerId(0));
+ // minor number matches
+ List used =
+ fpgaResourceHandler.getFpgaAllocator().
+ getUsedFpga().get(getContainerId(0).toString());
+ int count = 0;
+ for (FpgaDevice device : used) {
+ if (device.getMinor() == 0){
+ count++;
+ }
+ if (device.getMinor() == 1) {
+ count++;
+ }
+ }
+ Assert.assertEquals("Unexpected used minor number in allocator",2, count);
+ List available =
+ fpgaResourceHandler
+ .getFpgaAllocator()
+ .getAvailableFpga()
+ .get(vendorType);
+ count = 0;
+ for (FpgaDevice device : available) {
+ if (device.getMinor() == 2) {
+ count++;
+ }
+ }
+ Assert.assertEquals("Unexpected available minor number in allocator",
+ 1, count);
+
+
+ // Case 2. Recover a not allowed device with minor number 5
+ Container c1 = mockContainer(1, 1, "GEMM");
+ assigned = new ArrayList<>();
+ assigned.add(new FpgaDevice(
+ vendorType, 247, 5, "acl0"));
+ // Mock we've stored the c1 states
+ mockStateStoreForContainer(c1, assigned);
+ boolean flag = false;
+ try {
+ fpgaResourceHandler.reacquireContainer(getContainerId(1));
+ } catch (ResourceHandlerException e) {
+ flag = true;
+ }
+ Assert.assertTrue(flag);
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(1,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+
+ // Case 3. recover a already used device by other container
+ Container c2 = mockContainer(2, 1, "GEMM");
+ assigned = new ArrayList<>();
+ assigned.add(new FpgaDevice(
+ vendorType, 247, 1, "acl0"));
+ // Mock we've stored the c2 states
+ mockStateStoreForContainer(c2, assigned);
+ flag = false;
+ try {
+ fpgaResourceHandler.reacquireContainer(getContainerId(2));
+ } catch (ResourceHandlerException e) {
+ flag = true;
+ }
+ Assert.assertTrue(flag);
+ Assert.assertEquals(2,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(1,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+
+ // Case 4. recover a normal container c3 with remaining minor device number 2
+ Container c3 = mockContainer(3, 1, "GEMM");
+ assigned = new ArrayList<>();
+ assigned.add(new FpgaDevice(
+ vendorType, 247, 2, "acl2"));
+ // Mock we've stored the c2 states
+ mockStateStoreForContainer(c3, assigned);
+ fpgaResourceHandler.reacquireContainer(getContainerId(3));
+ Assert.assertEquals(3,
+ fpgaResourceHandler.getFpgaAllocator().getUsedFpgaCount());
+ Assert.assertEquals(0,
+ fpgaResourceHandler.getFpgaAllocator().getAvailableFpgaCount());
+ }
+
+ @Test
+ public void testSha256CalculationFails() throws ResourceHandlerException {
+ expected.expect(ResourceHandlerException.class);
+ expected.expectMessage("Could not calculate SHA-256");
+
+ dummyAocx.delete();
+ fpgaResourceHandler.preStart(mockContainer(0, 1, "GEMM"));
+ }
+
+ @Test
+ public void testSha256CalculationSucceeds()
+ throws IOException, ResourceHandlerException {
+ mockVendorPlugin =
+ mockPlugin(vendorType, deviceList, dummyAocx.getAbsolutePath());
+ fpgaResourceHandler = new FpgaResourceHandlerImpl(mockContext,
+ mockCGroupsHandler, mockPrivilegedExecutor, mockVendorPlugin,
+ fpgaDiscoverer);
+
+ fpgaResourceHandler.bootstrap(configuration);
+ fpgaResourceHandler.preStart(mockContainer(0, 1, "GEMM"));
+
+ // IP file is assigned to the first device
+ List devices =
+ fpgaResourceHandler.getFpgaAllocator().getAllowedFpga();
+ FpgaDevice device = devices.get(0);
+ assertEquals("Hash value", EXPECTED_HASH, device.getAocxHash());
+ }
+
+ private void verifyDeniedDevices(ContainerId containerId,
+ List deniedDevices)
+ throws ResourceHandlerException, PrivilegedOperationException {
+ verify(mockCGroupsHandler, atLeastOnce()).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES, containerId.toString());
+
+ if (null != deniedDevices && !deniedDevices.isEmpty()) {
+ verify(mockPrivilegedExecutor, times(1)).executePrivilegedOperation(
+ new PrivilegedOperation(PrivilegedOperation.OperationType.FPGA, Arrays
+ .asList(FpgaResourceHandlerImpl.CONTAINER_ID_CLI_OPTION,
+ containerId.toString(),
+ FpgaResourceHandlerImpl.EXCLUDED_FPGAS_CLI_OPTION,
+ StringUtils.join(",", deniedDevices))), true);
+ } else if (deniedDevices.isEmpty()) {
+ verify(mockPrivilegedExecutor, times(1)).executePrivilegedOperation(
+ new PrivilegedOperation(PrivilegedOperation.OperationType.FPGA, Arrays
+ .asList(FpgaResourceHandlerImpl.CONTAINER_ID_CLI_OPTION,
+ containerId.toString())), true);
+ }
+ }
+
+ private static IntelFpgaOpenclPlugin mockPlugin(String type,
+ List list, String aocxPath) {
+ IntelFpgaOpenclPlugin plugin = mock(IntelFpgaOpenclPlugin.class);
+ when(plugin.initPlugin(any())).thenReturn(true);
+ when(plugin.getFpgaType()).thenReturn(type);
+ when(plugin.retrieveIPfilePath(anyString(),
+ anyString(), anyMap())).thenReturn(aocxPath);
+ when(plugin.configureIP(anyString(), any()))
+ .thenReturn(true);
+ when(plugin.discover(anyInt())).thenReturn(list);
+ return plugin;
+ }
+
+ private static Container mockContainer(int id, int numFpga, String IPID) {
+ Container c = mock(Container.class);
+
+ Resource res = Resource.newInstance(1024, 1);
+ ResourceMappings resMapping = new ResourceMappings();
+ res.setResourceValue(ResourceInformation.FPGA_URI, numFpga);
+ when(c.getResource()).thenReturn(res);
+ when(c.getResourceMappings()).thenReturn(resMapping);
+
+ when(c.getContainerId()).thenReturn(getContainerId(id));
+
+ ContainerLaunchContext clc = mock(ContainerLaunchContext.class);
+ Map envs = new HashMap<>();
+ if (numFpga > 0) {
+ envs.put("REQUESTED_FPGA_IP_ID", IPID);
+ }
+ when(c.getLaunchContext()).thenReturn(clc);
+ when(clc.getEnvironment()).thenReturn(envs);
+ when(c.getWorkDir()).thenReturn("/tmp");
+ ResourceSet resourceSet = new ResourceSet();
+ when(c.getResourceSet()).thenReturn(resourceSet);
+
+ return c;
+ }
+
+ private void mockStateStoreForContainer(Container container,
+ List assigned) {
+ ResourceMappings rmap = new ResourceMappings();
+ ResourceMappings.AssignedResources ar =
+ new ResourceMappings.AssignedResources();
+ ar.updateAssignedResources(new ArrayList<>(assigned));
+ rmap.addAssignedResources(ResourceInformation.FPGA_URI, ar);
+ when(container.getResourceMappings()).thenReturn(rmap);
+ runningContainersMap.put(container.getContainerId(), container);
+ }
+
+ private static ContainerId getContainerId(int id) {
+ return ContainerId.newContainerId(ApplicationAttemptId
+ .newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
+ }
+}
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/TestGpuResourceHandlerImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/TestGpuResourceHandlerImpl.java
new file mode 100644
index 00000000000..4b50454ea23
--- /dev/null
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/TestGpuResourceHandlerImpl.java
@@ -0,0 +1,525 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.ResourceInformation;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDevice;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDiscoverer;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
+import org.apache.hadoop.yarn.server.nodemanager.recovery.NMNullStateStoreService;
+import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
+import org.apache.hadoop.yarn.util.resource.CustomResourceTypesConfigurationProvider;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyList;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TestGpuResourceHandlerImpl {
+ private CGroupsHandler mockCGroupsHandler;
+ private PrivilegedOperationExecutor mockPrivilegedExecutor;
+ private GpuResourceHandlerImpl gpuResourceHandler;
+ private NMStateStoreService mockNMStateStore;
+ private ConcurrentHashMap runningContainersMap;
+ private GpuDiscoverer gpuDiscoverer;
+ private File testDataDirectory;
+
+ public void createTestDataDirectory() throws IOException {
+ String testDirectoryPath = getTestParentDirectory();
+ testDataDirectory = new File(testDirectoryPath);
+ FileUtils.deleteDirectory(testDataDirectory);
+ testDataDirectory.mkdirs();
+ }
+
+ private String getTestParentDirectory() {
+ File f = new File("target/temp/" + TestGpuResourceHandlerImpl.class.getName());
+ return f.getAbsolutePath();
+ }
+
+ private void touchFile(File f) throws IOException {
+ new FileOutputStream(f).close();
+ }
+
+ private Configuration createDefaultConfig() throws IOException {
+ Configuration conf = new YarnConfiguration();
+ File fakeBinary = setupFakeGpuDiscoveryBinary();
+ conf.set(YarnConfiguration.NM_GPU_PATH_TO_EXEC,
+ fakeBinary.getAbsolutePath());
+ return conf;
+ }
+
+ private File setupFakeGpuDiscoveryBinary() throws IOException {
+ File fakeBinary = new File(getTestParentDirectory() + "/fake-nvidia-smi");
+ touchFile(fakeBinary);
+ return fakeBinary;
+ }
+
+ @Before
+ public void setup() throws IOException {
+ createTestDataDirectory();
+
+ CustomResourceTypesConfigurationProvider.
+ initResourceTypes(ResourceInformation.GPU_URI);
+
+ mockCGroupsHandler = mock(CGroupsHandler.class);
+ mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
+ mockNMStateStore = mock(NMStateStoreService.class);
+
+ Configuration conf = new Configuration();
+
+ Context nmctx = mock(Context.class);
+ when(nmctx.getNMStateStore()).thenReturn(mockNMStateStore);
+ when(nmctx.getConf()).thenReturn(conf);
+ runningContainersMap = new ConcurrentHashMap<>();
+ when(nmctx.getContainers()).thenReturn(runningContainersMap);
+
+ gpuDiscoverer = new GpuDiscoverer();
+ gpuResourceHandler = new GpuResourceHandlerImpl(nmctx, mockCGroupsHandler,
+ mockPrivilegedExecutor, gpuDiscoverer);
+ }
+
+ @After
+ public void cleanupTestFiles() throws IOException {
+ FileUtils.deleteDirectory(testDataDirectory);
+ }
+
+ @Test
+ public void testBootStrap() throws Exception {
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0");
+
+ gpuDiscoverer.initialize(conf);
+
+ gpuResourceHandler.bootstrap(conf);
+ verify(mockCGroupsHandler, times(1)).initializeCGroupController(
+ CGroupsHandler.CGroupController.DEVICES);
+ }
+
+ private static ContainerId getContainerId(int id) {
+ return ContainerId.newContainerId(ApplicationAttemptId
+ .newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
+ }
+
+ private static Container mockContainerWithGpuRequest(int id, int numGpuRequest,
+ boolean dockerContainerEnabled) {
+ Container c = mock(Container.class);
+ when(c.getContainerId()).thenReturn(getContainerId(id));
+
+ Resource res = Resource.newInstance(1024, 1);
+ ResourceMappings resMapping = new ResourceMappings();
+
+ res.setResourceValue(ResourceInformation.GPU_URI, numGpuRequest);
+ when(c.getResource()).thenReturn(res);
+ when(c.getResourceMappings()).thenReturn(resMapping);
+
+ ContainerLaunchContext clc = mock(ContainerLaunchContext.class);
+ Map env = new HashMap<>();
+ if (dockerContainerEnabled) {
+ env.put(ContainerRuntimeConstants.ENV_CONTAINER_TYPE,
+ ContainerRuntimeConstants.CONTAINER_RUNTIME_DOCKER);
+ }
+ when(clc.getEnvironment()).thenReturn(env);
+ when(c.getLaunchContext()).thenReturn(clc);
+ return c;
+ }
+
+ private static Container mockContainerWithGpuRequest(int id,
+ int numGpuRequest) {
+ return mockContainerWithGpuRequest(id, numGpuRequest, false);
+ }
+
+ private void verifyDeniedDevices(ContainerId containerId,
+ List deniedDevices)
+ throws ResourceHandlerException, PrivilegedOperationException {
+ verify(mockCGroupsHandler, times(1)).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES, containerId.toString());
+
+ if (null != deniedDevices && !deniedDevices.isEmpty()) {
+ List deniedDevicesMinorNumber = new ArrayList<>();
+ for (GpuDevice deniedDevice : deniedDevices) {
+ deniedDevicesMinorNumber.add(deniedDevice.getMinorNumber());
+ }
+ verify(mockPrivilegedExecutor, times(1)).executePrivilegedOperation(
+ new PrivilegedOperation(PrivilegedOperation.OperationType.GPU, Arrays
+ .asList(GpuResourceHandlerImpl.CONTAINER_ID_CLI_OPTION,
+ containerId.toString(),
+ GpuResourceHandlerImpl.EXCLUDED_GPUS_CLI_OPTION,
+ StringUtils.join(",", deniedDevicesMinorNumber))), true);
+ }
+ }
+
+ private void commonTestAllocation(boolean dockerContainerEnabled)
+ throws Exception {
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
+ gpuDiscoverer.initialize(conf);
+
+ gpuResourceHandler.bootstrap(conf);
+ Assert.assertEquals(4,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+
+ /* Start container 1, asks 3 containers */
+ gpuResourceHandler.preStart(
+ mockContainerWithGpuRequest(1, 3, dockerContainerEnabled));
+
+ // Only device=4 will be blocked.
+ if (dockerContainerEnabled) {
+ verifyDeniedDevices(getContainerId(1), Collections.emptyList());
+ } else{
+ verifyDeniedDevices(getContainerId(1), Arrays.asList(new GpuDevice(3,4)));
+ }
+
+ /* Start container 2, asks 2 containers. Excepted to fail */
+ boolean failedToAllocate = false;
+ try {
+ gpuResourceHandler.preStart(
+ mockContainerWithGpuRequest(2, 2, dockerContainerEnabled));
+ } catch (ResourceHandlerException e) {
+ failedToAllocate = true;
+ }
+ Assert.assertTrue(failedToAllocate);
+
+ /* Start container 3, ask 1 container, succeeded */
+ gpuResourceHandler.preStart(
+ mockContainerWithGpuRequest(3, 1, dockerContainerEnabled));
+
+ // devices = 0/1/3 will be blocked
+ if (dockerContainerEnabled) {
+ verifyDeniedDevices(getContainerId(3), Collections.emptyList());
+ } else {
+ verifyDeniedDevices(getContainerId(3), Arrays
+ .asList(new GpuDevice(0, 0), new GpuDevice(1, 1),
+ new GpuDevice(2, 3)));
+ }
+
+
+ /* Start container 4, ask 0 container, succeeded */
+ gpuResourceHandler.preStart(
+ mockContainerWithGpuRequest(4, 0, dockerContainerEnabled));
+
+ if (dockerContainerEnabled) {
+ verifyDeniedDevices(getContainerId(4), Collections.emptyList());
+ } else{
+ // All devices will be blocked
+ verifyDeniedDevices(getContainerId(4), Arrays
+ .asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3),
+ new GpuDevice(3, 4)));
+ }
+
+ /* Release container-1, expect cgroups deleted */
+ gpuResourceHandler.postComplete(getContainerId(1));
+
+ verify(mockCGroupsHandler, times(1)).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES, getContainerId(1).toString());
+ Assert.assertEquals(3,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+
+ /* Release container-3, expect cgroups deleted */
+ gpuResourceHandler.postComplete(getContainerId(3));
+
+ verify(mockCGroupsHandler, times(1)).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES, getContainerId(3).toString());
+ Assert.assertEquals(4,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+ }
+
+ @Test
+ public void testAllocationWhenDockerContainerEnabled() throws Exception {
+ // When docker container is enabled, no devices should be written to
+ // devices.deny.
+ commonTestAllocation(true);
+ }
+
+ @Test
+ public void testAllocation() throws Exception {
+ commonTestAllocation(false);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testAssignedGpuWillBeCleanedupWhenStoreOpFails()
+ throws Exception {
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
+ gpuDiscoverer.initialize(conf);
+
+ gpuResourceHandler.bootstrap(conf);
+ Assert.assertEquals(4,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+
+ doThrow(new IOException("Exception ...")).when(mockNMStateStore)
+ .storeAssignedResources(
+ any(Container.class), anyString(), anyList());
+
+ boolean exception = false;
+ /* Start container 1, asks 3 containers */
+ try {
+ gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 3));
+ } catch (ResourceHandlerException e) {
+ exception = true;
+ }
+
+ Assert.assertTrue("preStart should throw exception", exception);
+
+ // After preStart, we still have 4 available GPU since the store op fails.
+ Assert.assertEquals(4,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+ }
+
+ @Test
+ public void testAllocationWithoutAllowedGpus() throws Exception {
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, " ");
+ gpuDiscoverer.initialize(conf);
+
+ try {
+ gpuResourceHandler.bootstrap(conf);
+ Assert.fail("Should fail because no GPU available");
+ } catch (ResourceHandlerException e) {
+ // Expected because of no resource available
+ }
+
+ /* Start container 1, asks 0 containers */
+ gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 0));
+ verifyDeniedDevices(getContainerId(1), Collections.emptyList());
+
+ /* Start container 2, asks 1 containers. Excepted to fail */
+ boolean failedToAllocate = false;
+ try {
+ gpuResourceHandler.preStart(mockContainerWithGpuRequest(2, 1));
+ } catch (ResourceHandlerException e) {
+ failedToAllocate = true;
+ }
+ Assert.assertTrue(failedToAllocate);
+
+ /* Release container 1, expect cgroups deleted */
+ gpuResourceHandler.postComplete(getContainerId(1));
+
+ verify(mockCGroupsHandler, times(1)).createCGroup(
+ CGroupsHandler.CGroupController.DEVICES, getContainerId(1).toString());
+ Assert.assertEquals(0,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+ }
+
+ @Test
+ public void testAllocationStored() throws Exception {
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
+ gpuDiscoverer.initialize(conf);
+
+ gpuResourceHandler.bootstrap(conf);
+ Assert.assertEquals(4,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+
+ /* Start container 1, asks 3 containers */
+ Container container = mockContainerWithGpuRequest(1, 3);
+ gpuResourceHandler.preStart(container);
+
+ verify(mockNMStateStore).storeAssignedResources(container,
+ ResourceInformation.GPU_URI, Arrays
+ .asList(new GpuDevice(0, 0), new GpuDevice(1, 1),
+ new GpuDevice(2, 3)));
+
+ // Only device=4 will be blocked.
+ verifyDeniedDevices(getContainerId(1), Arrays.asList(new GpuDevice(3, 4)));
+
+ /* Start container 2, ask 0 container, succeeded */
+ container = mockContainerWithGpuRequest(2, 0);
+ gpuResourceHandler.preStart(container);
+
+ verifyDeniedDevices(getContainerId(2), Arrays
+ .asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3),
+ new GpuDevice(3, 4)));
+ Assert.assertEquals(0, container.getResourceMappings()
+ .getAssignedResources(ResourceInformation.GPU_URI).size());
+
+ // Store assigned resource will not be invoked.
+ verify(mockNMStateStore, never()).storeAssignedResources(
+ eq(container), eq(ResourceInformation.GPU_URI), anyList());
+ }
+
+ @Test
+ public void testAllocationStoredWithNULLStateStore() throws Exception {
+ NMNullStateStoreService mockNMNULLStateStore = mock(NMNullStateStoreService.class);
+
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
+
+ Context nmnctx = mock(Context.class);
+ when(nmnctx.getNMStateStore()).thenReturn(mockNMNULLStateStore);
+ when(nmnctx.getConf()).thenReturn(conf);
+
+ GpuResourceHandlerImpl gpuNULLStateResourceHandler =
+ new GpuResourceHandlerImpl(nmnctx, mockCGroupsHandler,
+ mockPrivilegedExecutor, gpuDiscoverer);
+
+ gpuDiscoverer.initialize(conf);
+
+ gpuNULLStateResourceHandler.bootstrap(conf);
+ Assert.assertEquals(4,
+ gpuNULLStateResourceHandler.getGpuAllocator().getAvailableGpus());
+
+ /* Start container 1, asks 3 containers */
+ Container container = mockContainerWithGpuRequest(1, 3);
+ gpuNULLStateResourceHandler.preStart(container);
+
+ verify(nmnctx.getNMStateStore()).storeAssignedResources(container,
+ ResourceInformation.GPU_URI, Arrays
+ .asList(new GpuDevice(0, 0), new GpuDevice(1, 1),
+ new GpuDevice(2, 3)));
+ }
+
+ @Test
+ public void testRecoverResourceAllocation() throws Exception {
+ Configuration conf = createDefaultConfig();
+ conf.set(YarnConfiguration.NM_GPU_ALLOWED_DEVICES, "0:0,1:1,2:3,3:4");
+ gpuDiscoverer.initialize(conf);
+
+ gpuResourceHandler.bootstrap(conf);
+ Assert.assertEquals(4,
+ gpuResourceHandler.getGpuAllocator().getAvailableGpus());
+
+ Container nmContainer = mock(Container.class);
+ ResourceMappings rmap = new ResourceMappings();
+ ResourceMappings.AssignedResources ar =
+ new ResourceMappings.AssignedResources();
+ ar.updateAssignedResources(
+ Arrays.asList(new GpuDevice(1, 1), new GpuDevice(2, 3)));
+ rmap.addAssignedResources(ResourceInformation.GPU_URI, ar);
+ when(nmContainer.getResourceMappings()).thenReturn(rmap);
+
+ runningContainersMap.put(getContainerId(1), nmContainer);
+
+ // TEST CASE
+ // Reacquire container restore state of GPU Resource Allocator.
+ gpuResourceHandler.reacquireContainer(getContainerId(1));
+
+ Map deviceAllocationMapping =
+ gpuResourceHandler.getGpuAllocator().getDeviceAllocationMappingCopy();
+ Assert.assertEquals(2, deviceAllocationMapping.size());
+ Assert.assertTrue(
+ deviceAllocationMapping.keySet().contains(new GpuDevice(1, 1)));
+ Assert.assertTrue(
+ deviceAllocationMapping.keySet().contains(new GpuDevice(2, 3)));
+ Assert.assertEquals(deviceAllocationMapping.get(new GpuDevice(1, 1)),
+ getContainerId(1));
+
+ // TEST CASE
+ // Try to reacquire a container but requested device is not in allowed list.
+ nmContainer = mock(Container.class);
+ rmap = new ResourceMappings();
+ ar = new ResourceMappings.AssignedResources();
+ // id=5 is not in allowed list.
+ ar.updateAssignedResources(
+ Arrays.asList(new GpuDevice(3, 4), new GpuDevice(4, 5)));
+ rmap.addAssignedResources(ResourceInformation.GPU_URI, ar);
+ when(nmContainer.getResourceMappings()).thenReturn(rmap);
+
+ runningContainersMap.put(getContainerId(2), nmContainer);
+
+ boolean caughtException = false;
+ try {
+ gpuResourceHandler.reacquireContainer(getContainerId(1));
+ } catch (ResourceHandlerException e) {
+ caughtException = true;
+ }
+ Assert.assertTrue(
+ "Should fail since requested device Id is not in allowed list",
+ caughtException);
+
+ // Make sure internal state not changed.
+ deviceAllocationMapping =
+ gpuResourceHandler.getGpuAllocator().getDeviceAllocationMappingCopy();
+ Assert.assertEquals(2, deviceAllocationMapping.size());
+ Assert.assertTrue(deviceAllocationMapping.keySet()
+ .containsAll(Arrays.asList(new GpuDevice(1, 1), new GpuDevice(2, 3))));
+ Assert.assertEquals(deviceAllocationMapping.get(new GpuDevice(1, 1)),
+ getContainerId(1));
+
+ // TEST CASE
+ // Try to reacquire a container but requested device is already assigned.
+ nmContainer = mock(Container.class);
+ rmap = new ResourceMappings();
+ ar = new ResourceMappings.AssignedResources();
+ // id=3 is already assigned
+ ar.updateAssignedResources(
+ Arrays.asList(new GpuDevice(3, 4), new GpuDevice(2, 3)));
+ rmap.addAssignedResources("gpu", ar);
+ when(nmContainer.getResourceMappings()).thenReturn(rmap);
+
+ runningContainersMap.put(getContainerId(2), nmContainer);
+
+ caughtException = false;
+ try {
+ gpuResourceHandler.reacquireContainer(getContainerId(1));
+ } catch (ResourceHandlerException e) {
+ caughtException = true;
+ }
+ Assert.assertTrue(
+ "Should fail since requested device Id is already assigned",
+ caughtException);
+
+ // Make sure internal state not changed.
+ deviceAllocationMapping =
+ gpuResourceHandler.getGpuAllocator().getDeviceAllocationMappingCopy();
+ Assert.assertEquals(2, deviceAllocationMapping.size());
+ Assert.assertTrue(deviceAllocationMapping.keySet()
+ .containsAll(Arrays.asList(new GpuDevice(1, 1), new GpuDevice(2, 3))));
+ Assert.assertEquals(deviceAllocationMapping.get(new GpuDevice(1, 1)),
+ getContainerId(1));
+ }
+}